Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions modules/test/tls/python/src/tls_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ def generate_module_report(self):
if len(certificates) > 0:

cert_tables = []
for cert_num, ((ip_address, port), # pylint: disable=W0612
cert) in enumerate(certificates.items()):
# pylint: disable=W0612
for cert_num, (
(ip_address, port),
cert) in enumerate(certificates.items()):

# Add summary table
summary_table = '''
Expand Down Expand Up @@ -246,7 +248,13 @@ def generate_module_report(self):

cert_tables.append(summary_table)

html_content += '\n'.join('\n' + tables for tables in cert_tables)
outbound_conns = self._tls_util.get_all_outbound_connections(
device_mac=self._device_mac, capture_files=pcap_files)
conn_table = self.generate_outbound_connection_table(outbound_conns)

html_content += summary_table + '\n'.join('\n' + tables
for tables in cert_tables)
html_content += conn_table

else:
html_content += ('''
Expand Down Expand Up @@ -316,6 +324,34 @@ def format_extension_value(self, value):
f'crl_sign={value.crl_sign}')
return str(value) # Fallback to string conversion

def generate_outbound_connection_table(self, outbound_conns):
"""Generate just an HTML table from a list of IPs"""
html_content = '''
<h1>Outbound Connections</h1>
<table class="module-data">
<thead>
<tr>
<th>Destination IP</th>
<th>Port</th>
</tr>
</thead>
<tbody>
'''

rows = [
f'\t<tr><td>{ip}</td><td>{port}</td></tr>'
for ip, port in outbound_conns
]
html_content += '\n'.join(rows)

# Close the table
html_content += """
</tbody>
\r</table>
"""

return html_content

def extract_certificates_from_pcap(self, pcap_files, mac_address):
# Initialize a list to store packets
all_packets = []
Expand Down
47 changes: 47 additions & 0 deletions modules/test/tls/python/src/tls_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from ipaddress import IPv4Address
from scapy.all import rdpcap, IP, Ether, TCP, UDP

LOG_NAME = 'tls_util'
LOGGER = None
Expand All @@ -37,6 +38,7 @@
ipaddress.ip_network('172.16.0.0/12'),
ipaddress.ip_network('192.168.0.0/16')
]
TR_CONTAINER_MAC_PREFIX = '9a:02:57:1e:8f:'
#Define the allowed protocols as tshark filters
DEFAULT_ALLOWED_PROTOCOLS = ['quic']

Expand All @@ -59,6 +61,51 @@ def __init__(self,
if allowed_protocols is None:
self._allowed_protocols = DEFAULT_ALLOWED_PROTOCOLS

def get_all_outbound_connections(self, device_mac, capture_files):
"""Process multiple pcap files and combine unique IP destinations."""

outbound_conns = set()
for capture in capture_files:
ips = self.get_outbound_connections(device_mac=device_mac,
capture_file=capture)
outbound_conns.update(ips)
return list(outbound_conns)

def get_outbound_connections(self, device_mac, capture_file):
"""Extract unique IP and port destinations from a single pcap file
based on the known MAC address."""
packets = rdpcap(capture_file)
outbound_conns = set()
for packet in packets:
if Ether in packet and IP in packet:
if packet[Ether].src == device_mac:
ip_dst = packet[IP].dst
port_dst = 'Unknown'

# Check if the packet has TCP or UDP layer to get the destination port
if TCP in packet:
port_dst = packet[TCP].dport
elif UDP in packet:
port_dst = packet[UDP].dport

if self.is_external_ip(ip_dst):
outbound_conns.add((ip_dst, port_dst))

return outbound_conns

def is_external_ip(self, ip):
"""Check if the IP is an external (non-private) IP address."""
try:
# Convert the IP string into an IPv4Address object
ip_addr = ipaddress.ip_address(ip)

# Return True only if the IP is not in a private or reserved range
return not (ip_addr.is_private or ip_addr.is_loopback
or ip_addr.is_link_local)
except ValueError:
# Return False if the IP is invalid or not IPv4
return False

def get_public_certificate(self,
host,
port=443,
Expand Down
27 changes: 26 additions & 1 deletion testing/unit/tls/tls_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module run all the TLS related unit tests"""
from tls_util import TLSUtil
from tls_module import TLSModule
from tls_util import TLSUtil
import os
import unittest
from common import logger
Expand Down Expand Up @@ -321,6 +321,28 @@ def security_tls_client_allowed_protocols_test(self):
print(str(test_results))
self.assertTrue(test_results[0])

def outbound_connections_test(self):
""" Test generation of the outbound connection ips"""
print('\noutbound_connections_test')
capture_file = os.path.join(CAPTURES_DIR, 'monitor.pcap')
ip_dst = TLS_UTIL.get_all_outbound_connections(
device_mac='70:b3:d5:96:c0:00', capture_files=[capture_file])
print(str(ip_dst))
# Compare as sets since returned order is not guaranteed
self.assertEqual(
set(ip_dst),
set(['8.8.8.8', '224.0.0.22', '18.140.82.197', '216.239.35.0']))

def outbound_connections_report_test(self):
""" Test generation of the outbound connection ips"""
print('\noutbound_connections_report_test')
capture_file = os.path.join(CAPTURES_DIR, 'monitor.pcap')
ip_dst = TLS_UTIL.get_all_outbound_connections(
device_mac='70:b3:d5:96:c0:00', capture_files=[capture_file])
tls = TLSModule(module=MODULE)
gen_html = tls.generate_outbound_connection_table(ip_dst)
print(gen_html)

def tls_module_report_test(self):
print('\ntls_module_report_test')
os.environ['DEVICE_MAC'] = '38:d1:35:01:17:fe'
Expand Down Expand Up @@ -583,6 +605,9 @@ def download_public_cert(self, hostname, port=443):

# suite.addTest(TLSModuleTest('security_tls_client_allowed_protocols_test'))

suite.addTest(TLSModuleTest('outbound_connections_test'))
suite.addTest(TLSModuleTest('outbound_connections_report_test'))

runner = unittest.TextTestRunner()
test_result = runner.run(suite)

Expand Down