Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 36 additions & 35 deletions cyberbattle/simulation/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ class NodeTrackingInformation:
# Map (vulnid, local_or_remote) to time of last attack.
# local_or_remote is true for local, false for remote
last_attack: Dict[Tuple[model.VulnerabilityID, bool], time] = dataclasses.field(default_factory=dict)
# Last time another node connected to this node
last_connection: Optional[time] = None
# Last time the node got owned by the attacker agent
last_owned_at: Optional[time] = None
# All node properties discovered so far
discovered_properties: Set[int] = dataclasses.field(default_factory=set)

Expand Down Expand Up @@ -218,15 +218,26 @@ def __mark_allnodeproperties_as_discovered(self, node_id: model.NodeID):

def __mark_node_as_owned(self,
node_id: model.NodeID,
privilege: PrivilegeLevel = model.PrivilegeLevel.LocalUser) -> None:
if node_id not in self._discovered_nodes:
self._discovered_nodes[node_id] = NodeTrackingInformation()
privilege: PrivilegeLevel = model.PrivilegeLevel.LocalUser) -> Tuple[Optional[time], bool]:
"""Mark a node as owned.
Return the time it was previously own (or None) and whether it was already owned."""
node_info = self._environment.get_node(node_id)
node_info.agent_installed = True
node_info.privilege_level = model.escalate(node_info.privilege_level, privilege)
self._environment.network.nodes[node_id].update({'data': node_info})

self.__mark_allnodeproperties_as_discovered(node_id)
last_owned_at, is_currently_owned = self.__is_node_owned_history(node_id, node_info)

if not is_currently_owned:
if node_id not in self._discovered_nodes:
self._discovered_nodes[node_id] = NodeTrackingInformation()
node_info.agent_installed = True
node_info.privilege_level = model.escalate(node_info.privilege_level, privilege)
self._environment.network.nodes[node_id].update({'data': node_info})

self.__mark_allnodeproperties_as_discovered(node_id)

# Record that the node just got owned at the current time
self._discovered_nodes[node_id].last_owned_at = time()

return last_owned_at, is_currently_owned

def __mark_discovered_entities(self, reference_node: model.NodeID, outcome: model.VulnerabilityOutcome) -> Tuple[int, float, int]:
"""Mark discovered entities as such and return
Expand Down Expand Up @@ -313,28 +324,25 @@ def __process_outcome(self,

reward = 0

was_previously_owned_at, is_currently_owned = self.__is_node_owned_history(node_id, node_info)

# if the vulnerability type is a privilege escalation
# and if the escalation level is not already reached on that node,
# then add the escalation tag to the node properties
if isinstance(outcome, model.PrivilegeEscalation):
if outcome.tag in node_info.properties:
return False, ActionResult(reward=Penalty.REPEAT, outcome=outcome)

if not was_previously_owned_at:
reward += float(node_info.value)
last_owned_at, is_currently_owned = self.__mark_node_as_owned(node_id, outcome.level)

self.__mark_node_as_owned(node_id, outcome.level)
if not last_owned_at:
reward += float(node_info.value)

node_info.properties.append(outcome.tag)

elif isinstance(outcome, model.LateralMove):
if not was_previously_owned_at:
reward += float(node_info.value)
last_owned_at, is_currently_owned = self.__mark_node_as_owned(node_id)

if not is_currently_owned:
self.__mark_node_as_owned(node_id)
if not last_owned_at:
reward += float(node_info.value)

elif isinstance(outcome, model.ProbeSucceeded):
for p in outcome.discovered_properties:
Expand Down Expand Up @@ -454,14 +462,12 @@ def __is_passing_firewall_rules(self, rules: List[model.FirewallRule], port_name
return False

def __is_node_owned_history(self, target_node_id, target_node_data):
""" Returns whether the node was previously owned and if it's still currently owned."""
was_previously_owned_at = self._discovered_nodes[target_node_id].last_connection
self._discovered_nodes[target_node_id].last_connection = time()
""" Returns the last time the node got owned and whether it is still currently owned."""
last_previously_owned_at = self._discovered_nodes[target_node_id].last_owned_at if target_node_id in self._discovered_nodes else None

is_currently_owned = was_previously_owned_at is not None and \
target_node_data.last_reimaging is not None and \
was_previously_owned_at >= target_node_data.last_reimaging
return was_previously_owned_at, is_currently_owned
is_currently_owned = last_previously_owned_at is not None and \
(target_node_data.last_reimaging is None or last_previously_owned_at >= target_node_data.last_reimaging)
return last_previously_owned_at, is_currently_owned

def connect_to_remote_machine(
self,
Expand Down Expand Up @@ -525,27 +531,22 @@ def connect_to_remote_machine(
return ActionResult(reward=Penalty.WRONG_PASSWORD,
outcome=None)

is_already_owned = target_node_data.agent_installed
last_owned_at, is_already_owned = self.__mark_node_as_owned(target_node_id)

if is_already_owned:
return ActionResult(reward=Penalty.REPEAT,
outcome=model.LateralMove())
return ActionResult(reward=Penalty.REPEAT, outcome=model.LateralMove())

if target_node_id not in self._discovered_nodes:
self._discovered_nodes[target_node_id] = NodeTrackingInformation()

was_previously_owned_at, is_currently_owned = self.__is_node_owned_history(target_node_id, target_node_data)

if is_currently_owned:
return ActionResult(reward=Penalty.REPEAT, outcome=model.LateralMove())

self.__annotate_edge(source_node_id, target_node_id, EdgeAnnotation.LATERAL_MOVE)
self.__mark_node_as_owned(target_node_id)

logger.info(f"Infected node '{target_node_id}' from '{source_node_id}'" +
f" via {port_name} with credential '{credential}'")
if target_node.owned_string:
logger.info("Owned message: " + target_node.owned_string)

return ActionResult(reward=float(target_node_data.value) if was_previously_owned_at is None else 0.0,
return ActionResult(reward=float(target_node_data.value) if last_owned_at is None else 0.0,
outcome=model.LateralMove())

def _check_service_running_and_authorized(self,
Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/simulation/actions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def test_exploit_remote_vulnerability(actions_on_simple_environment: Fixture) ->
# test a valid and functional one.
result = actions_on_simple_environment.exploit_remote_vulnerability('a', 'c', "RDPBF")
assert isinstance(result.outcome, model.LateralMove)
assert result.reward >= node.value
assert result.reward < node.value


def test_exploit_local_vulnerability(actions_on_simple_environment: Fixture) -> None:
Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/simulation/commandcontrol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,6 @@ def test_toyctf() -> None:

reward = command.total_reward()
print('Total reward ' + str(reward))
assert reward == 289.0
assert reward == 389.0
assert github is not None
pass