diff --git a/.gitignore b/.gitignore index 5a216522f..7ef392c5e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,4 @@ error pylint.out __pycache__/ build/ -testing/unit_test/temp \ No newline at end of file +testing/unit_test/temp/ diff --git a/framework/python/src/test_orc/test_orchestrator.py b/framework/python/src/test_orc/test_orchestrator.py index 74e399df1..61b94a995 100644 --- a/framework/python/src/test_orc/test_orchestrator.py +++ b/framework/python/src/test_orc/test_orchestrator.py @@ -27,6 +27,7 @@ RUNTIME_DIR = "runtime/test" TEST_MODULES_DIR = "modules/test" MODULE_CONFIG = "conf/module_config.json" +DEVICE_ROOT_CERTS = "local/root_certs" class TestOrchestrator: @@ -61,6 +62,9 @@ def start(self): os.makedirs(RUNTIME_DIR, exist_ok=True) util.run_command(f"chown -R {self._host_user} {RUNTIME_DIR}") + # Setup the root_certs folder + os.makedirs(DEVICE_ROOT_CERTS, exist_ok=True) + self._load_test_modules() self.build_test_modules() diff --git a/local/.gitignore b/local/.gitignore index f13ce8d85..d3086d4df 100644 --- a/local/.gitignore +++ b/local/.gitignore @@ -1,3 +1,3 @@ -system.json -devices -root_certs \ No newline at end of file +system.json +devices +root_certs diff --git a/modules/test/base/base.Dockerfile b/modules/test/base/base.Dockerfile index 62ff54d6c..707136f6d 100644 --- a/modules/test/base/base.Dockerfile +++ b/modules/test/base/base.Dockerfile @@ -17,10 +17,14 @@ FROM ubuntu:jammy ARG MODULE_NAME=base ARG MODULE_DIR=modules/test/$MODULE_NAME +ARG COMMON_DIR=framework/python/src/common # Install common software RUN apt-get update && apt-get install -y net-tools iputils-ping tcpdump iproute2 jq python3 python3-pip dos2unix nmap --fix-missing +# Install common python modules +COPY $COMMON_DIR/ /testrun/python/src/common + # Setup the base python requirements COPY $MODULE_DIR/python /testrun/python diff --git a/modules/test/base/python/requirements.txt b/modules/test/base/python/requirements.txt index 9c4e2b056..9d9473d74 100644 --- a/modules/test/base/python/requirements.txt +++ b/modules/test/base/python/requirements.txt @@ -1,2 +1,3 @@ grpcio -grpcio-tools \ No newline at end of file +grpcio-tools +netifaces \ No newline at end of file diff --git a/modules/test/base/python/src/test_module.py b/modules/test/base/python/src/test_module.py index e949976fa..8bee611b9 100644 --- a/modules/test/base/python/src/test_module.py +++ b/modules/test/base/python/src/test_module.py @@ -100,7 +100,10 @@ def run_tests(self): if isinstance(result, bool): test['result'] = 'compliant' if result else 'non-compliant' else: - test['result'] = 'compliant' if result[0] else 'non-compliant' + if result[0] is None: + test['result'] = 'skipped' + else: + test['result'] = 'compliant' if result[0] else 'non-compliant' test['result_details'] = result[1] else: test['result'] = 'skipped' diff --git a/modules/test/conn/python/requirements.txt b/modules/test/conn/python/requirements.txt index 93b351f44..2b8d18750 100644 --- a/modules/test/conn/python/requirements.txt +++ b/modules/test/conn/python/requirements.txt @@ -1 +1 @@ -scapy \ No newline at end of file +pyOpenSSL \ No newline at end of file diff --git a/modules/test/tls/bin/check_cert_signature.sh b/modules/test/tls/bin/check_cert_signature.sh new file mode 100644 index 000000000..ebd4a7549 --- /dev/null +++ b/modules/test/tls/bin/check_cert_signature.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +ROOT_CERT=$1 +DEVICE_CERT=$2 + +echo "ROOT: $ROOT_CERT" +echo "DEVICE_CERT: $DEVICE_CERT" + +response=$(openssl verify -CAfile $ROOT_CERT $DEVICE_CERT) + +echo "$response" diff --git a/modules/test/tls/bin/get_ciphers.sh b/modules/test/tls/bin/get_ciphers.sh new file mode 100644 index 000000000..e82bbc180 --- /dev/null +++ b/modules/test/tls/bin/get_ciphers.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +CAPTURE_FILE=$1 +DST_IP=$2 +DST_PORT=$3 + +TSHARK_FILTER="ssl.handshake.ciphersuites and ip.dst==$DST_IP and tcp.dstport==$DST_PORT" +response=$(tshark -r $CAPTURE_FILE -Y "$TSHARK_FILTER" -Vx | grep 'Cipher Suite:' | awk '{$1=$1};1' | sed 's/Cipher Suite: //') + +echo "$response" diff --git a/modules/test/tls/bin/get_client_hello_packets.sh b/modules/test/tls/bin/get_client_hello_packets.sh new file mode 100644 index 000000000..13e42f791 --- /dev/null +++ b/modules/test/tls/bin/get_client_hello_packets.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +CAPTURE_FILE=$1 +SRC_IP=$2 +TLS_VERSION=$3 + +TSHARK_OUTPUT="-T json -e ip.src -e tcp.dstport -e ip.dst" +TSHARK_FILTER="ssl.handshake.type==1 and ip.src==$SRC_IP" + +if [[ $TLS_VERSION == '1.2' || -z $TLS_VERSION ]];then + TSHARK_FILTER=$TSHARK_FILTER "and ssl.handshake.version==0x0303" +elif [ $TLS_VERSION == '1.2' ];then + TSHARK_FILTER=$TSHARK_FILTER "and ssl.handshake.version==0x0304" +fi + +response=$(tshark -r $CAPTURE_FILE $TSHARK_OUTPUT $TSHARK_FILTER) + +echo "$response" + \ No newline at end of file diff --git a/modules/test/tls/bin/get_handshake_complete.sh b/modules/test/tls/bin/get_handshake_complete.sh new file mode 100644 index 000000000..de1eb887d --- /dev/null +++ b/modules/test/tls/bin/get_handshake_complete.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +CAPTURE_FILE=$1 +SRC_IP=$2 +DST_IP=$3 +TLS_VERSION=$4 + +TSHARK_FILTER="ip.src==$SRC_IP and ip.dst==$DST_IP " + +if [[ $TLS_VERSION == '1.2' || -z $TLS_VERSION ]];then + TSHARK_FILTER=$TSHARK_FILTER " and ssl.handshake.type==2 and tls.handshake.type==14 " +elif [ $TLS_VERSION == '1.2' ];then + TSHARK_FILTER=$TSHARK_FILTER "and ssl.handshake.type==2 and tls.handshake.extensions.supported_version==0x0304" +fi + +response=$(tshark -r $CAPTURE_FILE $TSHARK_FILTER) + +echo "$response" + \ No newline at end of file diff --git a/modules/test/tls/bin/start_test_module b/modules/test/tls/bin/start_test_module new file mode 100644 index 000000000..d8cede486 --- /dev/null +++ b/modules/test/tls/bin/start_test_module @@ -0,0 +1,56 @@ +#!/bin/bash + +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# An example startup script that does the bare minimum to start +# a test module via a pyhon script. Each test module should include a +# start_test_module file that overwrites this one to boot all of its +# specific requirements to run. + +# Define where the python source files are located +PYTHON_SRC_DIR=/testrun/python/src + +# Fetch module name +MODULE_NAME=$1 + +# Default interface should be veth0 for all containers +DEFAULT_IFACE=veth0 + +# Allow a user to define an interface by passing it into this script +DEFINED_IFACE=$2 + +# Select which interace to use +if [[ -z $DEFINED_IFACE || "$DEFINED_IFACE" == "null" ]] +then + echo "No interface defined, defaulting to veth0" + INTF=$DEFAULT_IFACE +else + INTF=$DEFINED_IFACE +fi + +# Create and set permissions on the log files +LOG_FILE=/runtime/output/$MODULE_NAME.log +RESULT_FILE=/runtime/output/$MODULE_NAME-result.json +touch $LOG_FILE +touch $RESULT_FILE +chown $HOST_USER $LOG_FILE +chown $HOST_USER $RESULT_FILE + +# Run the python scrip that will execute the tests for this module +# -u flag allows python print statements +# to be logged by docker by running unbuffered +python3 -u $PYTHON_SRC_DIR/run.py "-m $MODULE_NAME" + +echo Module has finished \ No newline at end of file diff --git a/modules/test/tls/conf/module_config.json b/modules/test/tls/conf/module_config.json new file mode 100644 index 000000000..59e5a839d --- /dev/null +++ b/modules/test/tls/conf/module_config.json @@ -0,0 +1,37 @@ +{ + "config": { + "meta": { + "name": "tls", + "display_name": "TLS", + "description": "TLS tests" + }, + "network": true, + "docker": { + "depends_on": "base", + "enable_container": true, + "timeout": 300 + }, + "tests":[ + { + "name": "security.tls.v1_2_server", + "description": "Check the device web server TLS 1.2 & certificate is valid", + "expected_behavior": "TLS 1.2 certificate is issued to the web browser client when accessed" + }, + { + "name": "security.tls.v1_3_server", + "description": "Check the device web server TLS 1.3 & certificate is valid", + "expected_behavior": "TLS 1.3 certificate is issued to the web browser client when accessed" + }, + { + "name": "security.tls.v1_2_client", + "description": "Device uses TLS with connection to an external service on port 443 (or any other port which could be running the webserver-HTTPS)", + "expected_behavior": "The packet indicates a TLS connection with at least TLS 1.2 and support for ECDH and ECDSA ciphers" + }, + { + "name": "security.tls.v1_3_client", + "description": "Device uses TLS with connection to an external service on port 443 (or any other port which could be running the webserver-HTTPS)", + "expected_behavior": "The packet indicates a TLS connection with at least TLS 1.3" + } + ] + } +} \ No newline at end of file diff --git a/modules/test/tls/python/requirements.txt b/modules/test/tls/python/requirements.txt new file mode 100644 index 000000000..432116ff2 --- /dev/null +++ b/modules/test/tls/python/requirements.txt @@ -0,0 +1,2 @@ +cryptography +pyOpenSSL \ No newline at end of file diff --git a/modules/test/tls/python/src/run.py b/modules/test/tls/python/src/run.py new file mode 100644 index 000000000..51bc82f8f --- /dev/null +++ b/modules/test/tls/python/src/run.py @@ -0,0 +1,68 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Run Baseline module""" +import argparse +import signal +import sys +import logger + +from tls_module import TLSModule + +LOGGER = logger.get_logger('test_module') +RUNTIME = 1500 + + +class TLSModuleRunner: + """An example runner class for test modules.""" + + def __init__(self, module): + + signal.signal(signal.SIGINT, self._handler) + signal.signal(signal.SIGTERM, self._handler) + signal.signal(signal.SIGABRT, self._handler) + signal.signal(signal.SIGQUIT, self._handler) + + LOGGER.info('Starting TLS Module') + + self._test_module = TLSModule(module) + self._test_module.run_tests() + + def _handler(self, signum): + LOGGER.debug('SigtermEnum: ' + str(signal.SIGTERM)) + LOGGER.debug('Exit signal received: ' + str(signum)) + if signum in (2, signal.SIGTERM): + LOGGER.info('Exit signal received. Stopping test module...') + LOGGER.info('Test module stopped') + sys.exit(1) + + +def run(): + parser = argparse.ArgumentParser( + description='Security Module Help', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument( + '-m', + '--module', + help='Define the module name to be used to create the log file') + + args = parser.parse_args() + + # For some reason passing in the args from bash adds an extra + # space before the argument so we'll just strip out extra space + TLSModuleRunner(args.module.strip()) + + +if __name__ == '__main__': + run() diff --git a/modules/test/tls/python/src/tls_module.py b/modules/test/tls/python/src/tls_module.py new file mode 100644 index 000000000..d58163266 --- /dev/null +++ b/modules/test/tls/python/src/tls_module.py @@ -0,0 +1,108 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Baseline test module""" +from test_module import TestModule +from tls_util import TLSUtil + +LOG_NAME = 'test_tls' +LOGGER = None +STARTUP_CAPTURE_FILE = '/runtime/device/startup.pcap' +MONITOR_CAPTURE_FILE = '/runtime/device/monitor.pcap' + + +class TLSModule(TestModule): + """An example testing module.""" + + def __init__(self, module): + super().__init__(module_name=module, log_name=LOG_NAME) + global LOGGER + LOGGER = self._get_logger() + self._tls_util = TLSUtil(LOGGER) + + def _security_tls_v1_2_server(self): + LOGGER.info('Running security.tls.v1_2_server') + self._resolve_device_ip() + # If the ipv4 address wasn't resolved yet, try again + if self._device_ipv4_addr is not None: + tls_1_2_results = self._tls_util.validate_tls_server( + self._device_ipv4_addr, tls_version='1.2') + tls_1_3_results = self._tls_util.validate_tls_server( + self._device_ipv4_addr, tls_version='1.3') + return self._tls_util.process_tls_server_results(tls_1_2_results, + tls_1_3_results) + else: + LOGGER.error('Could not resolve device IP address. Skipping') + return None, 'Could not resolve device IP address. Skipping' + + def _security_tls_v1_3_server(self): + LOGGER.info('Running security.tls.v1_3_server') + self._resolve_device_ip() + # If the ipv4 address wasn't resolved yet, try again + if self._device_ipv4_addr is not None: + return self._tls_util.validate_tls_server(self._device_ipv4_addr, + tls_version='1.3') + else: + LOGGER.error('Could not resolve device IP address. Skipping') + return None, 'Could not resolve device IP address. Skipping' + + def _security_tls_v1_2_client(self): + LOGGER.info('Running security.tls.v1_2_client') + self._resolve_device_ip() + # If the ipv4 address wasn't resolved yet, try again + if self._device_ipv4_addr is not None: + return self._validate_tls_client(self._device_ipv4_addr, '1.2') + else: + LOGGER.error('Could not resolve device IP address. Skipping') + return None, 'Could not resolve device IP address. Skipping' + + def _security_tls_v1_3_client(self): + LOGGER.info('Running security.tls.v1_3_client') + self._resolve_device_ip() + # If the ipv4 address wasn't resolved yet, try again + if self._device_ipv4_addr is not None: + return self._validate_tls_client(self._device_ipv4_addr, '1.3') + else: + LOGGER.error('Could not resolve device IP address. Skipping') + return None, 'Could not resolve device IP address. Skipping' + + def _validate_tls_client(self, client_ip, tls_version): + monitor_result = self._tls_util.validate_tls_client( + client_ip=client_ip, + tls_version=tls_version, + capture_file=MONITOR_CAPTURE_FILE) + startup_result = self._tls_util.validate_tls_client( + client_ip=client_ip, + tls_version=tls_version, + capture_file=STARTUP_CAPTURE_FILE) + + LOGGER.info('Montor: ' + str(monitor_result)) + LOGGER.info('Startup: ' + str(startup_result)) + + if (not monitor_result[0] and monitor_result[0] is not None) or ( + not startup_result[0] and startup_result[0] is not None): + result = False, startup_result[1] + monitor_result[1] + elif monitor_result[0] and startup_result[0]: + result = True, startup_result[1] + monitor_result[1] + elif monitor_result[0] and startup_result[0] is None: + result = True, monitor_result[1] + elif startup_result[0] and monitor_result[0] is None: + result = True, monitor_result[1] + else: + result = None, startup_result[1] + return result + + def _resolve_device_ip(self): + # If the ipv4 address wasn't resolved yet, try again + if self._device_ipv4_addr is None: + self._device_ipv4_addr = self._get_device_ipv4() diff --git a/modules/test/tls/python/src/tls_module_test.py b/modules/test/tls/python/src/tls_module_test.py new file mode 100644 index 000000000..84a1c70eb --- /dev/null +++ b/modules/test/tls/python/src/tls_module_test.py @@ -0,0 +1,268 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module run all the TLS related unit tests""" +from tls_util import TLSUtil +import unittest +from common import logger +from scapy.all import sniff, wrpcap +import os +import threading +import time +import netifaces +import ssl +import http.client + +CAPTURE_DIR = 'testing/unit_test/temp' +MODULE_NAME = 'tls_module_test' +TLS_UTIL = None +PACKET_CAPTURE = None + + +class TLSModuleTest(unittest.TestCase): + """Contains and runs all the unit tests concerning TLS behaviors""" + @classmethod + def setUpClass(cls): + log = logger.get_logger(MODULE_NAME) + global TLS_UTIL + TLS_UTIL = TLSUtil(log, + bin_dir='modules/test/tls/bin', + cert_out_dir='testing/unit_test/temp', + root_certs_dir='local/root_certs') + + # Test 1.2 server when only 1.2 connection is established + def security_tls_v1_2_server_test(self): + tls_1_2_results = TLS_UTIL.validate_tls_server('google.com', + tls_version='1.2') + tls_1_3_results = None, 'No TLS 1.3' + test_results = TLS_UTIL.process_tls_server_results(tls_1_2_results, + tls_1_3_results) + self.assertTrue(test_results[0]) + + # Test 1.2 server when 1.3 connection is established + def security_tls_v1_2_for_1_3_server_test(self): + tls_1_2_results = None, 'No TLS 1.2' + tls_1_3_results = TLS_UTIL.validate_tls_server('google.com', + tls_version='1.3') + test_results = TLS_UTIL.process_tls_server_results(tls_1_2_results, + tls_1_3_results) + self.assertTrue(test_results[0]) + + # Test 1.2 server when 1.2 and 1.3 connection is established + def security_tls_v1_2_for_1_2_and_1_3_server_test(self): + tls_1_2_results = TLS_UTIL.validate_tls_server('google.com', + tls_version='1.2') + tls_1_3_results = TLS_UTIL.validate_tls_server('google.com', + tls_version='1.3') + test_results = TLS_UTIL.process_tls_server_results(tls_1_2_results, + tls_1_3_results) + self.assertTrue(test_results[0]) + + # Test 1.2 server when 1.2 and failed 1.3 connection is established + def security_tls_v1_2_for_1_2_and_1_3_fail_server_test(self): + tls_1_2_results = TLS_UTIL.validate_tls_server('google.com', + tls_version='1.2') + tls_1_3_results = False, 'Signature faild' + test_results = TLS_UTIL.process_tls_server_results(tls_1_2_results, + tls_1_3_results) + self.assertTrue(test_results[0]) + + # Test 1.2 server when 1.3 and failed 1.2 connection is established + def security_tls_v1_2_for_1_3_and_1_2_fail_server_test(self): + tls_1_3_results = TLS_UTIL.validate_tls_server('google.com', + tls_version='1.3') + tls_1_2_results = False, 'Signature faild' + test_results = TLS_UTIL.process_tls_server_results(tls_1_2_results, + tls_1_3_results) + self.assertTrue(test_results[0]) + + # Test 1.2 server when 1.3 and 1.2 failed connection is established + def security_tls_v1_2_fail_server_test(self): + tls_1_2_results = False, 'Signature faild' + tls_1_3_results = False, 'Signature faild' + test_results = TLS_UTIL.process_tls_server_results(tls_1_2_results, + tls_1_3_results) + self.assertFalse(test_results[0]) + + # Test 1.2 server when 1.3 and 1.2 failed connection is established + def security_tls_v1_2_none_server_test(self): + tls_1_2_results = None, 'No cert' + tls_1_3_results = None, 'No cert' + test_results = TLS_UTIL.process_tls_server_results(tls_1_2_results, + tls_1_3_results) + self.assertIsNone(test_results[0]) + + def security_tls_v1_3_server_test(self): + test_results = TLS_UTIL.validate_tls_server('google.com', tls_version='1.3') + self.assertTrue(test_results[0]) + + def security_tls_v1_2_client_test(self): + test_results = self.test_client_tls('1.2') + print(str(test_results)) + self.assertTrue(test_results[0]) + + def security_tls_v1_2_client_cipher_fail_test(self): + test_results = self.test_client_tls('1.2', disable_valid_ciphers=True) + print(str(test_results)) + self.assertFalse(test_results[0]) + + def security_tls_client_skip_test(self): + # 1.1 will fail to connect and so no hello client will exist + # which should result in a skip result + test_results = self.test_client_tls('1.2', tls_generate='1.1') + print(str(test_results)) + self.assertIsNone(test_results[0]) + + def security_tls_v1_3_client_test(self): + test_results = self.test_client_tls('1.3') + print(str(test_results)) + self.assertTrue(test_results[0]) + + def client_hello_packets_test(self): + packet_fail = {'dst_ip': '10.10.10.1', 'src_ip': '10.10.10.14', 'dst_port': '443', 'cipher_support': {'ecdh': False, 'ecdsa': True}} + packet_success = {'dst_ip': '10.10.10.1', 'src_ip': '10.10.10.14', 'dst_port': '443', 'cipher_support': {'ecdh': True, 'ecdsa': True}} + hello_packets = [packet_fail,packet_success] + hello_results = TLS_UTIL.process_hello_packets(hello_packets,'1.2') + print("Hello packets test results: " + str(hello_results)) + expected = {'valid':[packet_success],'invalid':[]} + self.assertEqual(hello_results,expected) + + def test_client_tls(self, + tls_version, + tls_generate=None, + disable_valid_ciphers=False): + # Make the capture file + os.makedirs(CAPTURE_DIR, exist_ok=True) + capture_file = CAPTURE_DIR + '/client_tls.pcap' + + # Resolve the client ip used + client_ip = self.get_interface_ip('eth0') + + # Genrate TLS outbound traffic + if tls_generate is None: + tls_generate = tls_version + self.generate_tls_traffic(capture_file, tls_generate, disable_valid_ciphers) + + # Run the client test + return TLS_UTIL.validate_tls_client(client_ip=client_ip, + tls_version=tls_version, + capture_file=capture_file) + + def generate_tls_traffic(self, + capture_file, + tls_version, + disable_valid_ciphers=False): + capture_thread = self.start_capture_thread(10) + print('Capture Started') + + # Generate some TLS 1.2 outbound traffic + while capture_thread.is_alive(): + self.make_tls_connection('www.google.com', 443, tls_version, + disable_valid_ciphers) + time.sleep(1) + + # Save the captured packets to the file. + wrpcap(capture_file, PACKET_CAPTURE) + + def make_tls_connection(self, + hostname, + port, + tls_version, + disable_valid_ciphers=False): + # Create the SSL context with the desired TLS version and options + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + context.options |= ssl.PROTOCOL_TLS + + if disable_valid_ciphers: + # Create a list of ciphers that do not use ECDH or ECDSA + ciphers_str = [ + 'TLS_AES_256_GCM_SHA384', 'TLS_CHACHA20_POLY1305_SHA256', + 'TLS_AES_128_GCM_SHA256', 'AES256-GCM-SHA384', + 'PSK-AES256-GCM-SHA384', 'PSK-CHACHA20-POLY1305', + 'RSA-PSK-AES128-GCM-SHA256', 'DHE-PSK-AES128-GCM-SHA256', + 'AES128-GCM-SHA256', 'PSK-AES128-GCM-SHA256', 'AES256-SHA256', + 'AES128-SHA' + ] + context.set_ciphers(':'.join(ciphers_str)) + + if tls_version != '1.1': + context.options |= ssl.OP_NO_TLSv1 # Disable TLS 1.0 + context.options |= ssl.OP_NO_TLSv1_1 # Disable TLS 1.1 + else: + context.options |= ssl.OP_NO_TLSv1_2 # Disable TLS 1.2 + context.options |= ssl.OP_NO_TLSv1_3 # Disable TLS 1.3 + + if tls_version == '1.3': + context.options |= ssl.OP_NO_TLSv1_2 # Disable TLS 1.2 + elif tls_version == '1.2': + context.options |= ssl.OP_NO_TLSv1_3 # Disable TLS 1.3 + + # Create the HTTPS connection with the SSL context + connection = http.client.HTTPSConnection(hostname, port, context=context) + + # Perform the TLS handshake manually + try: + connection.connect() + except ssl.SSLError as e: + print('Failed to make connection: ' + str(e)) + + # At this point, the TLS handshake is complete. + # You can do any further processing or just close the connection. + connection.close() + + def start_capture(self, timeout): + global PACKET_CAPTURE + PACKET_CAPTURE = sniff(iface='eth0', timeout=timeout) + + def start_capture_thread(self, timeout): + # Start the packet capture in a separate thread to avoid blocking. + capture_thread = threading.Thread(target=self.start_capture, + args=(timeout, )) + capture_thread.start() + + return capture_thread + + def get_interface_ip(self, interface_name): + try: + addresses = netifaces.ifaddresses(interface_name) + ipv4 = addresses[netifaces.AF_INET][0]['addr'] + return ipv4 + except (ValueError, KeyError) as e: + print(f'Error: {e}') + return None + + +if __name__ == '__main__': + suite = unittest.TestSuite() + suite.addTest(TLSModuleTest('client_hello_packets_test')) + # TLS 1.2 server tests + suite.addTest(TLSModuleTest('security_tls_v1_2_server_test')) + suite.addTest(TLSModuleTest('security_tls_v1_2_for_1_3_server_test')) + suite.addTest(TLSModuleTest('security_tls_v1_2_for_1_2_and_1_3_server_test')) + suite.addTest( + TLSModuleTest('security_tls_v1_2_for_1_2_and_1_3_fail_server_test')) + suite.addTest( + TLSModuleTest('security_tls_v1_2_for_1_3_and_1_2_fail_server_test')) + suite.addTest(TLSModuleTest('security_tls_v1_2_fail_server_test')) + suite.addTest(TLSModuleTest('security_tls_v1_2_none_server_test')) + # # TLS 1.3 server tests + suite.addTest(TLSModuleTest('security_tls_v1_3_server_test')) + # TLS client tests + suite.addTest(TLSModuleTest('security_tls_v1_2_client_test')) + suite.addTest(TLSModuleTest('security_tls_v1_3_client_test')) + suite.addTest(TLSModuleTest('security_tls_client_skip_test')) + suite.addTest(TLSModuleTest('security_tls_v1_2_client_cipher_fail_test')) + runner = unittest.TextTestRunner() + runner.run(suite) diff --git a/modules/test/tls/python/src/tls_util.py b/modules/test/tls/python/src/tls_util.py new file mode 100644 index 000000000..c83c131af --- /dev/null +++ b/modules/test/tls/python/src/tls_util.py @@ -0,0 +1,393 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module that contains various metehods for validating TLS communications""" +import ssl +import socket +from datetime import datetime +from OpenSSL import crypto +import json +import os +from common import util + +LOG_NAME = 'tls_util' +LOGGER = None +DEFAULT_BIN_DIR = '/testrun/bin' +DEFAULT_CERTS_OUT_DIR = '/runtime/output' +DEFAULT_ROOT_CERTS_DIR = '/testrun/root_certs' + + +class TLSUtil(): + """Helper class for various tests concerning TLS communications""" + + def __init__(self, + logger, + bin_dir=DEFAULT_BIN_DIR, + cert_out_dir=DEFAULT_CERTS_OUT_DIR, + root_certs_dir=DEFAULT_ROOT_CERTS_DIR): + global LOGGER + LOGGER = logger + self._bin_dir = bin_dir + self._dev_cert_file = cert_out_dir + '/device_cert.crt' + self._root_certs_dir = root_certs_dir + + def get_public_certificate(self, + host, + port=443, + validate_cert=False, + tls_version='1.2'): + try: + #context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + if not validate_cert: + # Disable certificate verification + context.verify_mode = ssl.CERT_NONE + else: + # Use host CA certs for validation + context.load_default_certs() + context.verify_mode = ssl.CERT_REQUIRED + + # Set the correct TLS version + context.options |= ssl.PROTOCOL_TLS + context.options |= ssl.OP_NO_TLSv1 # Disable TLS 1.0 + context.options |= ssl.OP_NO_TLSv1_1 # Disable TLS 1.1 + if tls_version == '1.3': + context.options |= ssl.OP_NO_TLSv1_2 # Disable TLS 1.2 + elif tls_version == '1.2': + context.options |= ssl.OP_NO_TLSv1_3 # Disable TLS 1.3 + + # Create an SSL/TLS socket + with socket.create_connection((host, port), timeout=5) as sock: + with context.wrap_socket(sock, server_hostname=host) as secure_sock: + # Get the server's certificate in PEM format + cert_pem = ssl.DER_cert_to_PEM_cert(secure_sock.getpeercert(True)) + + except ConnectionRefusedError: + LOGGER.info(f'Connection to {host}:{port} was refused.') + return None + except socket.gaierror: + LOGGER.info(f'Failed to resolve the hostname {host}.') + return None + except ssl.SSLError as e: + LOGGER.info(f'SSL error occurred: {e}') + return None + + return cert_pem + + def get_public_key(self, public_cert): + # Extract and return the public key from the certificate + public_key = public_cert.get_pubkey() + return public_key + + def verify_certificate_timerange(self, public_cert): + # Extract the notBefore and notAfter dates from the certificate + not_before = datetime.strptime(public_cert.get_notBefore().decode(), + '%Y%m%d%H%M%SZ') + not_after = datetime.strptime(public_cert.get_notAfter().decode(), + '%Y%m%d%H%M%SZ') + + LOGGER.info('Certificate valid from: ' + str(not_before) + ' To ' + + str(not_after)) + + # Get the current date + current_date = datetime.utcnow() + + # Check if today's date is within the certificate's validity range + if not_before <= current_date <= not_after: + return True, 'Certificate has a valid time range' + elif current_date <= not_before: + return False, 'Certificate is not yet valid' + else: + return False, 'Certificate has expired' + + def verify_public_key(self, public_key): + + # Get the key length based bits + key_length = public_key.bits() + LOGGER.info('Key Length: ' + str(key_length)) + + # Check the key type + key_type = 'Unknown' + if public_key.type() == crypto.TYPE_RSA: + key_type = 'RSA' + elif public_key.type() == crypto.TYPE_EC: + key_type = 'EC' + elif public_key.type() == crypto.TYPE_DSA: + key_type = 'DSA' + elif public_key.type() == crypto.TYPE_DH: + key_type = 'Diffie-Hellman' + LOGGER.info('Key Type: ' + key_type) + + # Check if the public key is of RSA type + if key_type == 'RSA': + if key_length >= 2048: + return True, 'RSA key length passed: ' + str(key_length) + ' >= 2048' + else: + return False, 'RSA key length too short: ' + str(key_length) + ' < 2048' + + # Check if the public key is of EC type + elif key_type == 'EC': + if key_length >= 224: + return True, 'EC key length passed: ' + str(key_length) + ' >= 224' + else: + return False, 'EC key length too short: ' + str(key_length) + ' < 224' + else: + return False, 'Key is not RSA or EC type' + + def validate_signature(self, host): + # Reconnect to the device but with validate signature option + # set to true which will check for proper cert chains + # within the valid CA root certs stored on the server + LOGGER.info( + 'Checking for valid signature from authorized Certificate Authorities') + public_cert = self.get_public_certificate(host, + validate_cert=True, + tls_version='1.2') + if public_cert: + LOGGER.info('Authorized Certificate Authority signature confirmed') + return True, 'Authorized Certificate Authority signature confirmed' + else: + LOGGER.info('Authorized Certificate Authority signature not present') + LOGGER.info('Resolving configured root certificates') + bin_file = self._bin_dir + '/check_cert_signature.sh' + # Get a list of all root certificates + root_certs = os.listdir(self._root_certs_dir) + LOGGER.info('Root Certs Found: ' + str(len(root_certs))) + for root_cert in root_certs: + try: + # Create the file path + root_cert_path = os.path.join(self._root_certs_dir, root_cert) + LOGGER.info('Checking root cert: ' + str(root_cert_path)) + args = f'{root_cert_path} {self._dev_cert_file}' + command = f'{bin_file} {args}' + response = util.run_command(command) + if 'device_cert.crt: OK' in str(response): + LOGGER.info('Device signed by cert:' + root_cert) + return True, 'Device signed by cert:' + root_cert + else: + LOGGER.info('Device not signed by cert: ' + root_cert) + except Exception as e: # pylint: disable=W0718 + LOGGER.error('Failed to check cert:' + root_cert) + LOGGER.error(str(e)) + return False, 'Device certificate has not been signed' + + def process_tls_server_results(self, tls_1_2_results, tls_1_3_results): + results = '' + if tls_1_2_results[0] is None and tls_1_3_results[0]: + results = True, 'TLS 1.3 validated:\n' + tls_1_3_results[1] + elif tls_1_3_results[0] is None and tls_1_2_results[0]: + results = True, 'TLS 1.2 validated:\n' + tls_1_2_results[1] + elif tls_1_2_results[0] and tls_1_3_results[0]: + description = 'TLS 1.2 validated:\n' + tls_1_2_results[1] + description += '\nTLS 1.3 validated:\n' + tls_1_3_results[1] + results = True, description + elif tls_1_2_results[0] and not tls_1_3_results[0]: + description = 'TLS 1.2 validated:\n' + tls_1_2_results[1] + description += '\nTLS 1.3 not validated:\n' + tls_1_3_results[1] + results = True, description + elif tls_1_3_results[0] and not tls_1_2_results[0]: + description = 'TLS 1.2 not validated:\n' + tls_1_2_results[1] + description += '\nTLS 1.3 validated:\n' + tls_1_3_results[1] + results = True, description + elif not tls_1_3_results[0] and not tls_1_2_results[0] and tls_1_2_results[ + 0] is not None and tls_1_3_results is not None: + description = 'TLS 1.2 not validated:\n' + tls_1_2_results[1] + description += '\nTLS 1.3 not validated:\n' + tls_1_3_results[1] + results = False, description + else: + description = 'TLS 1.2 not validated:\n' + tls_1_2_results[1] + description += '\nTLS 1.3 not validated:\n' + tls_1_3_results[1] + results = None, description + LOGGER.info('TLS 1.2 server test results: ' + str(results)) + return results + + def validate_tls_server(self, host, tls_version): + cert_pem = self.get_public_certificate(host, + validate_cert=False, + tls_version=tls_version) + if cert_pem: + + # Write pem encoding to a file + self.write_cert_to_file(cert_pem) + + # Load pem encoding into a certifiate so we can process the contents + public_cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) + + # Print the certificate information + cert_text = crypto.dump_certificate(crypto.FILETYPE_TEXT, + public_cert).decode() + LOGGER.info('Device Certificate:\n' + cert_text) + + # Validate the certificates time range + tr_valid = self.verify_certificate_timerange(public_cert) + + # Resolve the public key + public_key = self.get_public_key(public_cert) + if public_key: + key_valid = self.verify_public_key(public_key) + + sig_valid = self.validate_signature(host) + + # Check results + cert_valid = tr_valid[0] and key_valid[0] and sig_valid[0] + test_details = tr_valid[1] + '\n' + key_valid[1] + '\n' + sig_valid[1] + LOGGER.info('Certificate validated: ' + str(cert_valid)) + LOGGER.info('Test Details:\n' + test_details) + return cert_valid, test_details + else: + LOGGER.info('Failed to resolve public certificate') + return None, 'Failed to resolve public certificate' + + def write_cert_to_file(self, pem_cert): + with open(self._dev_cert_file, 'w', encoding='UTF-8') as f: + f.write(pem_cert) + + def get_ciphers(self, capture_file, dst_ip, dst_port): + bin_file = self._bin_dir + '/get_ciphers.sh' + args = f'{capture_file} {dst_ip} {dst_port}' + command = f'{bin_file} {args}' + response = util.run_command(command) + ciphers = response[0].split('\n') + return ciphers + + def get_hello_packets(self, capture_file, src_ip, tls_version): + bin_file = self._bin_dir + '/get_client_hello_packets.sh' + args = f'{capture_file} {src_ip} {tls_version}' + command = f'{bin_file} {args}' + response = util.run_command(command) + packets = response[0].strip() + return self.parse_hello_packets(json.loads(packets), capture_file) + + def get_handshake_complete(self, capture_file, src_ip, dst_ip, tls_version): + bin_file = self._bin_dir + '/get_handshake_complete.sh' + args = f'{capture_file} {src_ip} {dst_ip} {tls_version}' + command = f'{bin_file} {args}' + response = util.run_command(command) + return response + + def parse_hello_packets(self, packets, capture_file): + hello_packets = [] + for packet in packets: + # Extract all the basic IP information about the packet + packet_layers = packet['_source']['layers'] + dst_ip = packet_layers['ip.dst'][0] if 'ip.dst' in packet_layers else '' + src_ip = packet_layers['ip.src'][0] if 'ip.src' in packet_layers else '' + dst_port = packet_layers['tcp.dstport'][ + 0] if 'tcp.dstport' in packet_layers else '' + + # Resolve the ciphers used in this packet and validate expected ones exist + ciphers = self.get_ciphers(capture_file, dst_ip, dst_port) + cipher_support = self.is_ecdh_and_ecdsa(ciphers) + + # Put result together + hello_packet = {} + hello_packet['dst_ip'] = dst_ip + hello_packet['src_ip'] = src_ip + hello_packet['dst_port'] = dst_port + hello_packet['cipher_support'] = cipher_support + + hello_packets.append(hello_packet) + return hello_packets + + def process_hello_packets(self,hello_packets, tls_version = '1.2'): + # Validate the ciphers only for tls 1.2 + client_hello_results = {'valid': [], 'invalid': []} + if tls_version == '1.2': + for packet in hello_packets: + if packet['dst_ip'] not in str(client_hello_results['valid']): + LOGGER.info('Checking client ciphers: ' + str(packet)) + if packet['cipher_support']['ecdh'] and packet['cipher_support'][ + 'ecdsa']: + LOGGER.info('Valid ciphers detected') + client_hello_results['valid'].append(packet) + # If a previous hello packet to the same destination failed, + # we can now remove it as it has passed on a different attempt + if packet['dst_ip'] in str(client_hello_results['invalid']): + LOGGER.info(str(client_hello_results['invalid'])) + for invalid_packet in client_hello_results['invalid']: + if packet['dst_ip'] in str(invalid_packet): + client_hello_results['invalid'].remove(invalid_packet) + else: + LOGGER.info('Invalid ciphers detected') + if packet['dst_ip'] not in str(client_hello_results['invalid']): + client_hello_results['invalid'].append(packet) + else: + # No cipher check for TLS 1.3 + client_hello_results['valid'] = hello_packets + return client_hello_results + + def validate_tls_client(self, client_ip, tls_version, capture_file): + LOGGER.info('Validating client for TLS: ' + tls_version) + hello_packets = self.get_hello_packets(capture_file, client_ip, tls_version) + client_hello_results = self.process_hello_packets(hello_packets,tls_version) + + handshakes = {'complete': [], 'incomplete': []} + for packet in client_hello_results['valid']: + # Filter out already tested IP's since only 1 handshake success is needed + if not packet['dst_ip'] in handshakes['complete'] and not packet[ + 'dst_ip'] in handshakes['incomplete']: + handshake_complete = self.get_handshake_complete( + capture_file, packet['src_ip'], packet['dst_ip'], tls_version) + + # One of the responses will be a complaint about running as root so + # we have to have at least 2 entries to consider a completed handshake + if len(handshake_complete) > 1: + LOGGER.info('TLS handshake completed from: ' + packet['dst_ip']) + handshakes['complete'].append(packet['dst_ip']) + else: + LOGGER.warning('No TLS handshakes completed from: ' + + packet['dst_ip']) + handshakes['incomplete'].append(packet['dst_ip']) + + for handshake in handshakes['complete']: + LOGGER.info('Valid TLS client connection to server: ' + str(handshake)) + + # Process and return the results + tls_client_details = '' + tls_client_valid = None + if len(hello_packets) > 0: + if len(client_hello_results['invalid']) > 0: + tls_client_valid = False + for result in client_hello_results['invalid']: + tls_client_details += 'Client hello packet to ' + result[ + 'dst_ip'] + ' did not have expected ciphers:' + if not result['cipher_support']['ecdh']: + tls_client_details += ' ecdh ' + if not result['cipher_support']['ecdsa']: + tls_client_details += 'ecdsa' + tls_client_details += '\n' + if len(handshakes['incomplete']) > 0: + for result in handshakes['incomplete']: + tls_client_details += 'Incomplete handshake detected from server: ' + tls_client_details += result + '\n' + if len(handshakes['complete']) > 0: + # If we haven't already failed the test from previous checks + # allow a passing result + if tls_client_valid is None: + tls_client_valid = True + for result in handshakes['complete']: + tls_client_details += 'Completed handshake detected from server: ' + tls_client_details += result + '\n' + else: + LOGGER.info('No client hello packets detected. Skipping') + tls_client_details = 'No client hello packets detected. Skipping' + return tls_client_valid, tls_client_details + + def is_ecdh_and_ecdsa(self, ciphers): + ecdh = False + ecdsa = False + for cipher in ciphers: + ecdh |= 'ECDH' in cipher + ecdsa |= 'ECDSA' in cipher + return {'ecdh': ecdh, 'ecdsa': ecdsa} diff --git a/modules/test/tls/tls.Dockerfile b/modules/test/tls/tls.Dockerfile new file mode 100644 index 000000000..92fa6028c --- /dev/null +++ b/modules/test/tls/tls.Dockerfile @@ -0,0 +1,48 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Image name: test-run/tls-test +FROM test-run/base-test:latest + +# Set DEBIAN_FRONTEND to noninteractive mode +ENV DEBIAN_FRONTEND=noninteractive + +# Install required software +RUN apt-get update && apt-get install -y tshark + +ARG MODULE_NAME=tls +ARG MODULE_DIR=modules/test/$MODULE_NAME +ARG CERTS_DIR=local/root_certs + +# Copy over all configuration files +COPY $MODULE_DIR/conf /testrun/conf + +# Copy over all binary files +COPY $MODULE_DIR/bin /testrun/bin + +# Copy over all python files +COPY $MODULE_DIR/python /testrun/python + +#Install all python requirements for the module +RUN pip3 install -r /testrun/python/requirements.txt + +# Create a directory inside the container to store the root certificates +RUN mkdir -p /testrun/root_certs + +# Copy over all the local certificates for device signature +# checks if the folder exists +COPY $CERTS_DIR /testrun/root_certs + + + diff --git a/testing/unit/run_tests.sh b/testing/unit/run_tests.sh index 5b1ed6257..5fa1179b1 100644 --- a/testing/unit/run_tests.sh +++ b/testing/unit/run_tests.sh @@ -15,4 +15,8 @@ export PYTHONPATH="$PWD/framework/python/src" python3 -u $PWD/modules/network/dhcp-1/python/src/grpc_server/dhcp_config_test.py python3 -u $PWD/modules/network/dhcp-2/python/src/grpc_server/dhcp_config_test.py +# Run the Security Module Unit Tests +python3 -u $PWD/modules/test/tls/python/src/tls_module_test.py + + popd >/dev/null 2>&1 \ No newline at end of file