diff --git a/components/component-executor.py b/components/component-executor.py index 02d08cb1..2dd02813 100644 --- a/components/component-executor.py +++ b/components/component-executor.py @@ -26,25 +26,25 @@ # limitations under the License. # ############################################################################# -import base64 +from __future__ import annotations + import collections -import errno -import glob import http.client import json import os import shlex import signal import socket -import ssl import string import subprocess import sys import time -import traceback -import urllib.error -import urllib.request +from pathlib import Path +from typing import Any, Dict, NamedTuple, Optional, Tuple + +import component_registration +Descriptor = Dict[str, Any] def main(): executor_proc, tail_proc = init() @@ -56,63 +56,101 @@ def main(): sys.exit(exit_code) -def init(): - # Optional configurable environment variables - wfm_user = os.getenv('WFM_USER', 'admin') - wfm_password = os.getenv('WFM_PASSWORD', 'mpfadm') - wfm_base_url = os.getenv('WFM_BASE_URL', 'http://workflow-manager:8080/workflow-manager') - activemq_host = os.getenv('ACTIVE_MQ_HOST', 'activemq') - component_log_name = os.getenv('COMPONENT_LOG_NAME') - disable_component_registration = os.getenv('DISABLE_COMPONENT_REGISTRATION') - node_name = os.getenv('THIS_MPF_NODE') - - # Environment variables from base Docker image - mpf_home = os.getenv('MPF_HOME', '/opt/mpf') - base_log_path = os.getenv('MPF_LOG_PATH', os.path.join(mpf_home, 'share/logs')) - - - descriptor_path = find_descriptor(mpf_home) +class EnvConfig(NamedTuple): + wfm_user: str + wfm_password: str + wfm_base_url: str + oidc_issuer_uri: Optional[str] + oidc_client_id: Optional[str] + oidc_client_secret: Optional[str] + activemq_host: str + component_log_name: Optional[str] + disable_component_registration: bool + node_name: Optional[str] + mpf_home: Path + base_log_path: Path + + @staticmethod + def create(): + oidc_issuer_uri = os.getenv('OIDC_JWT_ISSUER_URI', os.getenv('OIDC_ISSUER_URI')) + oidc_client_id = os.getenv('OIDC_CLIENT_ID') + oidc_client_secret = os.getenv('OIDC_CLIENT_SECRET') + if oidc_issuer_uri and (not oidc_client_id or not oidc_client_secret): + raise RuntimeError( + 'The OIDC_CLIENT_ID and OIDC_CLIENT_SECRET environment variables must both ' + 'be set.') + + mpf_home = Path(os.getenv('MPF_HOME', '/opt/mpf')) + if log_path_str := os.getenv('MPF_LOG_PATH'): + log_path = Path(log_path_str) + else: + log_path = mpf_home / 'share/logs' + return EnvConfig( + os.getenv('WFM_USER', 'admin'), + os.getenv('WFM_PASSWORD', 'mpfadm'), + os.getenv('WFM_BASE_URL', 'http://workflow-manager:8080/workflow-manager'), + oidc_issuer_uri, + oidc_client_id, + oidc_client_secret, + os.getenv('ACTIVE_MQ_HOST', 'activemq'), + os.getenv('COMPONENT_LOG_NAME'), + bool(os.getenv('DISABLE_COMPONENT_REGISTRATION')), + os.getenv('THIS_MPF_NODE'), + mpf_home, + log_path) + + +def init() -> Tuple[subprocess.Popen[str], Optional[subprocess.Popen[bytes]]]: + env_config = EnvConfig.create() + + descriptor_path = find_descriptor(env_config.mpf_home) print('Loading descriptor from', descriptor_path) with open(descriptor_path, 'rb') as descriptor_file: unparsed_descriptor = descriptor_file.read() - if disable_component_registration: + if env_config.disable_component_registration: print('Component registration disabled because the ' '"DISABLE_COMPONENT_REGISTRATION" environment variable was set.') else: - register_component(unparsed_descriptor, wfm_base_url, wfm_user, wfm_password) + component_registration.register_component(env_config, unparsed_descriptor) - wait_for_activemq(activemq_host) + wait_for_activemq(env_config.activemq_host) descriptor = json.loads(unparsed_descriptor) - if not node_name: + if env_config.node_name: + node_name = env_config.node_name + else: component_name = descriptor['componentName'] - node_name = '{}_id_{}'.format(component_name, os.getenv('HOSTNAME')) - log_dir = os.path.join(base_log_path, node_name, 'log') + node_name = f'{component_name}_id_{os.getenv("HOSTNAME")}' + log_dir = env_config.base_log_path / node_name / 'log' - executor_proc = start_executor(descriptor, mpf_home, activemq_host, node_name) - tail_proc = tail_log_if_needed(log_dir, component_log_name, executor_proc.pid) + executor_proc = start_executor( + descriptor, env_config.mpf_home, env_config.activemq_host, node_name) + tail_proc = tail_log_if_needed(log_dir, env_config.component_log_name, executor_proc.pid) return executor_proc, tail_proc -def find_descriptor(mpf_home): - glob_pattern = os.path.join(mpf_home, 'plugins/*/descriptor/descriptor.json') - glob_matches = glob.glob(glob_pattern) +def find_descriptor(mpf_home: Path) -> Path: + glob_subpath = 'plugins/*/descriptor/descriptor.json' + glob_matches = list(mpf_home.glob(glob_subpath)) if len(glob_matches) == 1: return glob_matches[0] + + glob_pattern = str(mpf_home / glob_subpath) if len(glob_matches) == 0: - raise RuntimeError('Expecting to find a descriptor file at "{}", but it was not there.' - .format(glob_pattern)) + raise RuntimeError( + f'Expecting to find a descriptor file at "{glob_pattern}", but it was not there.') - if all(os.path.samefile(glob_matches[0], m) for m in glob_matches[1:]): + if all(glob_matches[0].samefile(m) for m in glob_matches[1:]): return glob_matches[0] - raise RuntimeError('Expected to find one descriptor matching "{}", but the following descriptors were found: {}' - .format(glob_pattern, glob_matches)) + raise RuntimeError( + f'Expected to find one descriptor matching "{glob_pattern}", but the following ' + f'descriptors were found: {glob_matches}') -def wait_for_activemq(activemq_host): +def wait_for_activemq(activemq_host: str) -> None: while True: try: conn = http.client.HTTPConnection(activemq_host, 8161) @@ -120,155 +158,55 @@ def wait_for_activemq(activemq_host): resp = conn.getresponse() if 200 <= resp.status <= 299 or resp.status == 401: return - print('Received non-success status code of {} when trying to connect to ActiveMQ. ' - 'This is either because ActiveMQ is still starting or the wrong host name was used for the ' - 'ACTIVE_MQ_HOST(={}) environment variable. Connection to ActiveMQ will re-attempted in 10 seconds.' - .format(resp.status, activemq_host)) + print(f'Received non-success status code of {resp.status} when trying to connect to ' + 'ActiveMQ. This is either because ActiveMQ is still starting or the wrong host ' + f'name was used for the ACTIVE_MQ_HOST(={activemq_host}) environment variable. ' + 'Connection to ActiveMQ will re-attempted in 10 seconds.') except socket.error as e: - print('Attempt to connect to ActiveMQ failed due to "{}". This is either because ActiveMQ is still ' - 'starting or the wrong host name was used for the ACTIVE_MQ_HOST(={}) environment variable. ' - 'Connection to ActiveMQ will re-attempted in 10 seconds.'.format(e, activemq_host)) + print(f'Attempt to connect to ActiveMQ failed due to "{e}". This is either because ' + 'ActiveMQ is still starting or the wrong host name was used for the ' + f'ACTIVE_MQ_HOST(={activemq_host}) environment variable. Connection to ActiveMQ ' + 'will re-attempted in 10 seconds.') time.sleep(10) -def register_component(unparsed_descriptor, wfm_base_url, wfm_user, wfm_password): - if not wfm_user or not wfm_password: - raise RuntimeError('The WFM_USER and WFM_PASSWORD environment variables must both be set.') - - auth_info_bytes = (wfm_user + ':' + wfm_password).encode('utf-8') - base64_bytes = base64.b64encode(auth_info_bytes) - headers = { - 'Authorization': 'Basic ' + base64_bytes.decode('utf-8'), - 'Content-Length': len(unparsed_descriptor), - 'Content-Type': 'application/json', - 'Accept': 'application/json' - } - - url = wfm_base_url + '/rest/components/registerUnmanaged' - print('Registering component by posting descriptor to', url) - try: - post_descriptor_with_retry(unparsed_descriptor, url, headers) - except urllib.error.HTTPError as err: - handle_registration_error(err) - - -def post_descriptor_with_retry(unparsed_descriptor, url, headers): - - ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ssl_ctx.verify_mode = ssl.CERT_NONE - opener = urllib.request.build_opener(ThrowingRedirectHandler(), urllib.request.HTTPSHandler(context=ssl_ctx)) - - while True: - try: - post_descriptor(unparsed_descriptor, url, headers, opener) - return - except http.client.BadStatusLine: - new_url = url.replace('http://', 'https://') - print('Initial registration response failed due to an invalid status line in the HTTP response. ' - 'This usually means that the server is using HTTPS, but an "http://" URL was used. ' - 'Trying again with:', new_url) - post_descriptor(unparsed_descriptor, new_url, headers, opener) - return - - except urllib.error.HTTPError as err: - if err.url == url: - raise - else: - # This generally means the provided WFM url used HTTP, but WFM was configured to use HTTPS - print('Initial registration response failed. Trying with redirected url: ', err.url) - post_descriptor(unparsed_descriptor, err.url, headers, opener) - return - - except urllib.error.URLError as err: - reason = err.reason - should_retry_immediately = isinstance(reason, ssl.SSLError) and reason.reason == 'UNKNOWN_PROTOCOL' - if should_retry_immediately: - new_url = url.replace('https://', 'http://') - print('Initial registration response failed due to an "UNKNOWN_PROTOCOL" SSL error. ' - 'This usually means that the server is using HTTP on the specified port, ' - 'but an "https://" URL was used. Trying again with:', new_url) - url = new_url - # continue since the post to new_url might cause a redirect which raises an error that can be - # handled. - continue - should_retry_after_delay = (isinstance(reason, socket.gaierror) and reason.errno == socket.EAI_NONAME - or isinstance(reason, socket.error) and reason.errno == errno.ECONNREFUSED) - if not should_retry_after_delay: - raise - print('Registration failed due to "{}". This is either because the Workflow Manager is still starting or ' - 'because the wrong URL was used for the WFM_BASE_URL(={}) environment variable. Registration will ' - 'be re-attempted in 10 seconds.'.format(reason.strerror, url)) - time.sleep(10) - - -# The default urllib.request.HTTPRedirectHandler converts POST requests to GET requests. -# This subclass just throws an exception so we can post to the new URL ourselves. -class ThrowingRedirectHandler(urllib.request.HTTPRedirectHandler): - def http_error_302(self, req, fp, code, msg, headers): - if 'location' in headers: - new_url = headers['location'] - elif 'uri' in headers: - new_url = headers['uri'] - else: - raise RuntimeError('Received HTTP redirect response with no location header.') - - raise urllib.error.HTTPError(new_url, code, msg, headers, fp) - - -def post_descriptor(unparsed_descriptor, url, headers, opener): - request = urllib.request.Request(url, unparsed_descriptor, headers=headers) - with opener.open(request) as response: - body = response.read() - print('Registration response:', body) - - -def handle_registration_error(http_error): - traceback.print_exc() - print(file=sys.stderr) - - response_content = http_error.read() - try: - server_message = json.loads(response_content)['message'] - except (ValueError, KeyError): - server_message = response_content - - error_msg = 'The following error occurred while trying to register component: {}: {}' \ - .format(http_error, server_message) - if http_error.code == 401: - error_msg += '\nThe WFM_USER and WFM_PASSWORD environment variables need to be changed.' - raise RuntimeError(error_msg) - - -def start_executor(descriptor, mpf_home, activemq_host, node_name): - activemq_broker_uri = 'failover://(tcp://{}:61616)?jms.prefetchPolicy.all=0&startupMaxReconnectAttempts=1'\ - .format(activemq_host) +def start_executor(descriptor: Descriptor, mpf_home: Path, activemq_host: str, node_name: str + ) -> subprocess.Popen[str]: + activemq_broker_uri = ( + f'failover://(tcp://{activemq_host}:61616)' + '?jms.prefetchPolicy.all=0&startupMaxReconnectAttempts=1') algorithm_name = descriptor['algorithm']['name'].upper() - queue_name = 'MPF.DETECTION_{}_REQUEST'.format(algorithm_name) + queue_name = f'MPF.DETECTION_{algorithm_name}_REQUEST' language = descriptor['sourceLanguage'].lower() executor_env = get_executor_env_vars(mpf_home, descriptor, node_name) if language in ('c++', 'python'): - amq_detection_component_path = os.path.join(mpf_home, 'bin/amq_detection_component') + amq_detection_component_path = str(mpf_home / 'bin/amq_detection_component') batch_lib = expand_env_vars(descriptor['batchLibrary'], executor_env) - executor_command = (amq_detection_component_path, activemq_broker_uri, batch_lib, queue_name, language) + executor_command = ( + amq_detection_component_path, activemq_broker_uri, batch_lib, queue_name, language) elif language == 'java': executor_jar = find_java_executor_jar(descriptor, mpf_home) - component_jar = os.path.join(mpf_home, 'plugins', descriptor['componentName'], descriptor['batchLibrary']) - class_path = executor_jar + ':' + component_jar + component_jar = ( + mpf_home / 'plugins' / descriptor['componentName'] / descriptor['batchLibrary']) + class_path = f'{executor_jar}:{component_jar}' executor_command = ('java', '--class-path', class_path, 'org.mitre.mpf.component.executor.detection.MPFDetectionMain', queue_name, activemq_broker_uri) else: - raise RuntimeError('Descriptor contained invalid sourceLanguage property. It must be c++, python, or java.') + raise RuntimeError( + 'Descriptor contained invalid sourceLanguage property. ' + 'It must be c++, python, or java.') print('Starting component executor with command:', shlex.join(executor_command)) - executor_proc = subprocess.Popen(executor_command, - env=executor_env, - cwd=os.path.join(mpf_home, 'plugins', descriptor['componentName']), - stdin=subprocess.PIPE, - text=True) + executor_proc = subprocess.Popen( + executor_command, + env=executor_env, + cwd=mpf_home / 'plugins' / descriptor['componentName'], + stdin=subprocess.PIPE, + text=True) # Handle ctrl-c signal.signal(signal.SIGINT, lambda sig, frame: stop_executor(executor_proc)) @@ -277,24 +215,27 @@ def start_executor(descriptor, mpf_home, activemq_host, node_name): return executor_proc -def find_java_executor_jar(descriptor, mpf_home): - java_executor_path_pattern = os.path.join(mpf_home, 'jars', 'mpf-java-component-executor-{}.jar') +def find_java_executor_jar(descriptor: Descriptor, mpf_home: Path) -> Path: + jars_dir = mpf_home / 'jars' middleware_version = descriptor['middlewareVersion'] - executor_matching_version_path = java_executor_path_pattern.format(middleware_version) - if os.path.exists(executor_matching_version_path): + executor_matching_version_path = ( + jars_dir / f'mpf-java-component-executor-{middleware_version}.jar') + if executor_matching_version_path.exists(): return executor_matching_version_path - executor_path_with_glob = java_executor_path_pattern.format('*') - glob_matches = glob.glob(executor_path_with_glob) + glob_subpath = 'mpf-java-component-executor-*.jar' + glob_matches = list(jars_dir.glob(glob_subpath)) if not glob_matches: - raise RuntimeError('Did not find the OpenMPF Java Executor jar at "{}".'.format(executor_path_with_glob)) + executor_path_with_glob = str(jars_dir / glob_subpath) + raise RuntimeError( + f'Did not find the OpenMPF Java Executor jar at "{executor_path_with_glob}".') expanded_executor_path = glob_matches[0] - print('WARNING: Did not find the OpenMPF Java Executor version "{}" at "{}". Using "{}" instead.' - .format(middleware_version, executor_matching_version_path, expanded_executor_path)) + print(f'WARNING: Did not find the OpenMPF Java Executor version "{middleware_version}" at ' + f'"{executor_matching_version_path}". Using "{expanded_executor_path}" instead.') return expanded_executor_path -def stop_executor(executor_proc): +def stop_executor(executor_proc: subprocess.Popen[str]) -> None: still_running = executor_proc.poll() is None if still_running: print('Sending quit to component executor') @@ -303,23 +244,17 @@ def stop_executor(executor_proc): executor_proc.stdin.flush() -def tail_log_if_needed(log_dir, component_log_name, executor_pid): +def tail_log_if_needed(log_dir: Path, component_log_name: Optional[str], executor_pid: int + ) -> Optional[subprocess.Popen[bytes]]: if not component_log_name: return None - if not os.path.exists(log_dir): - try: - os.makedirs(log_dir) - except OSError as e: - # Two components may both try to create the directory at the same time, - # so we ignore the error indicating that the directory already exists. - if e.errno != errno.EEXIST: - raise - - component_log_full_path = os.path.join(log_dir, component_log_name) - if not os.path.exists(component_log_full_path): + log_dir.mkdir(parents=True, exist_ok=True) + + component_log_full_path = log_dir / component_log_name + if not component_log_full_path.exists(): # Create file if it doesn't exist. - open(component_log_full_path, 'a').close() + component_log_full_path.touch(exist_ok=True) tail_command = ( 'tail', @@ -327,16 +262,16 @@ def tail_log_if_needed(log_dir, component_log_name, executor_pid): '--follow=name', # Watch executor process and exit when executor exists. '--pid', str(executor_pid), - component_log_full_path) + str(component_log_full_path)) print('Displaying logs with command: ', shlex.join(tail_command)) - # Use preexec_fn=os.setpgrp to prevent ctrl-c from killing tail since + # Use start_new_session to prevent ctrl-c from killing tail since # executor may write to log file when shutting down. - return subprocess.Popen(tail_command, preexec_fn=os.setpgrp) + return subprocess.Popen(tail_command, start_new_session=True) -def get_executor_env_vars(mpf_home, descriptor, node_name): +def get_executor_env_vars(mpf_home: Path, descriptor: Descriptor, node_name: str) -> Dict[str, str]: executor_env = os.environ.copy() executor_env['THIS_MPF_NODE'] = node_name executor_env['SERVICE_NAME'] = descriptor['componentName'] @@ -357,12 +292,12 @@ def get_executor_env_vars(mpf_home, descriptor, node_name): if ld_lib_path: ld_lib_path += ':' - executor_env['LD_LIBRARY_PATH'] = ld_lib_path + os.path.join(mpf_home, 'lib') + executor_env['LD_LIBRARY_PATH'] = ld_lib_path + str(mpf_home / 'lib') return executor_env # Expand environment variables and replace non-existent variables with an empty string. -def expand_env_vars(raw_str, env): +def expand_env_vars(raw_str: str, env: Dict[str, str]) -> str: # dict that returns empty string when key is missing. defaults = collections.defaultdict(str) # In the call to substitute the keyword arguments (**env) take precedence. diff --git a/components/component_registration.py b/components/component_registration.py new file mode 100644 index 00000000..71633291 --- /dev/null +++ b/components/component_registration.py @@ -0,0 +1,249 @@ +############################################################################# +# NOTICE # +# # +# This software (or technical data) was produced for the U.S. Government # +# under contract, and is subject to the Rights in Data-General Clause # +# 52.227-14, Alt. IV (DEC 2007). # +# # +# Copyright 2023 The MITRE Corporation. All Rights Reserved. # +############################################################################# + +############################################################################# +# Copyright 2023 The MITRE Corporation # +# # +# Licensed under the Apache License, Version 2.0 (the "License"); # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http://www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################# + +from __future__ import annotations + +import base64 +import errno +import http.client +import json +import socket +import ssl +import time +import urllib.error +import urllib.request +import urllib.response +from typing import Callable, Dict, NamedTuple, NoReturn, Tuple + + +def register_component(env_config, descriptor_bytes: bytes) -> None: + if env_config.oidc_issuer_uri: + OidcRegistration(env_config, descriptor_bytes + ).post_descriptor() + else: + BasicAuthRegistration(env_config, descriptor_bytes + ).post_descriptor() + + +def execute_http_request_with_retry(request_context: RequestContext) -> urllib.response.addinfourl: + url = request_context.url + while True: + request = request_context.request_builder(url) + try: + return _OPENER.open(request) + except http.client.BadStatusLine as err: + url = handle_bad_status(url, err) + except urllib.error.HTTPError as err: + url = handle_http_error(url, err, request_context) + except urllib.error.URLError as err: + url, should_wait = handle_url_error(url, err) + if should_wait: + time.sleep(10) + + +def handle_bad_status(url: str, error: http.client.BadStatusLine) -> str: + if url.startswith('https'): + raise error + new_url = url.replace('http://', 'https://') + print(f'HTTP request to {url} failed due to an invalid status line in the HTTP ' + 'response. This usually means that the server is using HTTPS, but an "http://" ' + 'URL was used. Trying again with:', new_url) + return new_url + + +def handle_http_error( + url: str, + error: urllib.error.HTTPError, + request_context: RequestContext) -> str: + if error.url != url: + print(f'Sending HTTP request to {url} resulted in a redirect to {error.url}.') + return error.url + + response_content = error.read() + try: + server_message = json.loads(response_content)['message'] + except (ValueError, KeyError): + server_message = response_content.decode() + + error_msg = f'The following error occurred while sending HTTP request to {url}: {error}' + if server_message: + error_msg += f' {server_message}' + if error.code == 401: + error_msg += f'\n{request_context.get_401_error_msg(error)}' + raise RuntimeError(error_msg) from error + + +_RETRYABLE_ERR_NOS = (socket.EAI_NONAME, socket.EAI_AGAIN, errno.ECONNREFUSED) + +def handle_url_error(url: str, error: urllib.error.URLError) -> Tuple[str, bool]: + reason = error.reason + is_unknown_protocol = isinstance(reason, ssl.SSLError) and reason.reason == 'UNKNOWN_PROTOCOL' + if is_unknown_protocol: + if url.startswith('http'): + raise error + new_url = url.replace('https://', 'http://') + print(f'HTTP request to {url} failed due to an "UNKNOWN_PROTOCOL" SSL ' + 'error. This usually means that the server is using HTTP on the specified ' + 'port, but an "https://" URL was used. Trying again with:', new_url) + return new_url, False + + if isinstance(reason, OSError) and reason.errno in _RETRYABLE_ERR_NOS: + print(f'HTTP request to {url} failed due to "{reason.strerror}". This is either ' + 'because the service is still starting up or the wrong URL was used. The ' + 'request will be re-attempted in 10 seconds.') + return url, True + raise error + + + +# The default urllib.request.HTTPRedirectHandler converts POST requests to GET requests. +# This subclass just throws an exception so we can post to the new URL ourselves. +class ThrowingRedirectHandler(urllib.request.HTTPRedirectHandler): + def http_error_302(self, req, fp, code, msg, headers) -> NoReturn: + new_url = headers.get('location') or headers.get('uri') + if new_url: + raise urllib.error.HTTPError(new_url, code, msg, headers, fp) + else: + raise RuntimeError('Received HTTP redirect response with no location header.') + + +def create_opener() -> urllib.request.OpenerDirector: + ssl_ctx = ssl.SSLContext() + return urllib.request.build_opener( + ThrowingRedirectHandler(), + urllib.request.HTTPSHandler(context=ssl_ctx)) + +_OPENER = create_opener() + + +class BasicAuthRegistration: + def __init__(self, env_config, descriptor_bytes: bytes): + self._wfm_user: str = env_config.wfm_user + self._wfm_password: str = env_config.wfm_password + self._wfm_base_url: str = env_config.wfm_base_url + self._descriptor_bytes: bytes = descriptor_bytes + + def post_descriptor(self) -> None: + url = self._wfm_base_url + '/rest/components/registerUnmanaged' + headers = create_basic_auth_header(self._wfm_user, self._wfm_password) + headers['Content-Type'] = 'application/json' + + print('Registering component by posting descriptor to', url) + request_context = RequestContext( + url, + lambda u: urllib.request.Request(u, self._descriptor_bytes, headers), + self.get_401_error_msg) + + with execute_http_request_with_retry(request_context): + # We don't need to do anything with the response. + pass + + @staticmethod + def get_401_error_msg(error: urllib.error.HTTPError): + if error.headers.get('WWW-Authenticate') == 'Bearer': + return ('Workflow Manager is configured to use OIDC, so the component must also be' + ' configured to use OIDC.') + else: + return 'The WFM_USER and WFM_PASSWORD environment variables need to be changed.' + + +class OidcRegistration: + def __init__(self, env_config, descriptor_bytes: bytes): + self._token_url: str = self._request_token_url(env_config.oidc_issuer_uri) + self._token: str = '' + self._reuse_token_until: float = 0.0 + self._wfm_base_url: str = env_config.wfm_base_url + self._client_id: str = env_config.oidc_client_id + self._client_secret: str = env_config.oidc_client_secret + self._descriptor_bytes: bytes = descriptor_bytes + self._request_auth_token() + + @classmethod + def _request_token_url(cls, oidc_issuer_uri: str) -> str: + config_url = oidc_issuer_uri + '/.well-known/openid-configuration' + print('Getting OIDC configuration metadata from', config_url) + request_context = RequestContext(config_url, urllib.request.Request, cls.get_401_error_msg) + with execute_http_request_with_retry(request_context) as resp: + return json.load(resp)['token_endpoint'] + + def _request_auth_token(self) -> None: + headers = create_basic_auth_header(self._client_id, self._client_secret) + + def create_request(url): + # Update token url in case there was a redirect. + self._token_url = url + return urllib.request.Request(url, b'grant_type=client_credentials', headers) + + print(f'Requesting token from {self._token_url}') + request_context = RequestContext(self._token_url, create_request, self.get_401_error_msg) + with execute_http_request_with_retry(request_context) as resp: + resp_content = json.load(resp) + + self._token = resp_content['access_token'] + expires_in = resp_content['expires_in'] + self._reuse_token_until = time.time() + expires_in + if expires_in > 60: + # Do not re-use token for full duration to account for clock drift and network latency. + self._reuse_token_until -= 60 + print(f'Received token that expires in {expires_in} seconds.') + + def post_descriptor(self) -> None: + url = self._wfm_base_url + '/rest/components/registerUnmanaged' + print('Registering component by posting descriptor to', url) + request_context = RequestContext( + url, self._create_post_descriptor_request, self.get_401_error_msg) + with execute_http_request_with_retry(request_context): + # We don't need to do anything with the response. + pass + + def _create_post_descriptor_request(self, url: str) -> urllib.request.Request: + if time.time() > self._reuse_token_until: + self._request_auth_token() + headers = { + 'Authorization': f'Bearer {self._token}', + 'Content-Type': 'application/json' + } + return urllib.request.Request(url, self._descriptor_bytes, headers) + + @staticmethod + def get_401_error_msg(error: urllib.error.HTTPError): + base_message = 'The OIDC environment variables need to be changed.' + if auth_header := error.headers.get('WWW-Authenticate'): + return f'The WWW-Authenticate header contained: {auth_header}\n{base_message}' + else: + return base_message + + +def create_basic_auth_header(user: str, password: str) -> Dict[str, str]: + auth_info_bytes = f'{user}:{password}'.encode() + base64_auth_info = base64.b64encode(auth_info_bytes).decode() + return {'Authorization': f'Basic {base64_auth_info}'} + + +class RequestContext(NamedTuple): + url: str + request_builder: Callable[[str], urllib.request.Request] + get_401_error_msg: Callable[[urllib.error.HTTPError], str] diff --git a/components/cpp_executor/Dockerfile b/components/cpp_executor/Dockerfile index 3ceaba2d..7c9b3022 100644 --- a/components/cpp_executor/Dockerfile +++ b/components/cpp_executor/Dockerfile @@ -91,7 +91,7 @@ ENV BUILD_DIR /home/mpf/component_build ENV PLUGINS_DIR $MPF_HOME/plugins -COPY docker-entrypoint.sh component-executor.py /scripts/ +COPY docker-entrypoint.sh component-executor.py component_registration.py /scripts/ COPY cli_runner/*.py cli_runner/Log4cxxConfig.xml /scripts/cli_runner/ diff --git a/components/java_executor/Dockerfile b/components/java_executor/Dockerfile index e83651f1..c0453f38 100644 --- a/components/java_executor/Dockerfile +++ b/components/java_executor/Dockerfile @@ -56,7 +56,7 @@ ENV PLUGINS_DIR $MPF_HOME/plugins COPY --from=openmpf_build /build-artifacts/java-executor/*.jar $MPF_HOME/jars/ -COPY component-executor.py /scripts/component-executor.py +COPY component-executor.py component_registration.py /scripts/ ENTRYPOINT ["python3", "-u", "/scripts/component-executor.py"] diff --git a/components/python/Dockerfile b/components/python/Dockerfile index 4b54b460..31eb20b8 100644 --- a/components/python/Dockerfile +++ b/components/python/Dockerfile @@ -111,7 +111,7 @@ ENV MPF_LOG_PATH $MPF_HOME/share/logs ENV PLUGINS_DIR $MPF_HOME/plugins -COPY docker-entrypoint.sh component-executor.py /scripts/ +COPY docker-entrypoint.sh component-executor.py component_registration.py /scripts/ COPY cli_runner/*.py /scripts/cli_runner/