From c3dc281e0ae6cd33276db7abf5853d11be23a4d0 Mon Sep 17 00:00:00 2001 From: Vladislav Kondratyev Date: Fri, 15 Aug 2025 12:44:59 -0500 Subject: [PATCH 01/21] improved systems --- README.md | 52 ++++++++++ src/sentinel/__init__.py | 21 +++- src/sentinel/score_formulae.py | 143 +++++++++++++++++++++++---- src/sentinel/score_types.py | 50 ++++++---- src/sentinel/sentinel_local_index.py | 68 +++++++++++-- tests/conftest.py | 9 ++ tests/test_score_formulae.py | 25 +++++ tests/test_sriracha_local_index.py | 10 ++ 8 files changed, 331 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 9150d9b..70b184b 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,22 @@ Roblox Sentinel, part of the Roblox Safety Toolkit, is a Python library designed By prioritizing recall over precision, Sentinel serves as a high-recall candidate generator for more thorough investigation. This approach is particularly effective for applications where rare patterns are critical to identify. Rather than treating each message in isolation, Sentinel analyzes patterns across messages to identify concerning behavior. +## What’s New: Aggregation options and Explainability + +Sentinel now includes multiple aggregation strategies and built‑in explainability to help you tune for your use case and understand why a score was assigned. + +- Aggregators (in `sentinel.score_formulae`): + - `skewness(scores, min_size_of_scores=10)`: default, pattern‑oriented and robust to message count + - `top_k_mean(scores, k=3)`: focuses on the strongest signals + - `percentile_score(scores, q=90.0)`: robust to outliers via a percentile over positives + - `softmax_weighted_mean(scores, temperature=1.0)`: smoothly emphasizes higher scores + - `max_score(scores)`: simplest, picks the highest positive score + +- Explainability (in results): + - Each call to `calculate_rare_class_affinity` returns a `RareClassAffinityResult` with: + - `aggregation_name`, `aggregation_stats`: which aggregator was used and key params + - `explanations`: per‑text details including top‑K positive/negative similarities, contrastive components, and neighbor snippets (when available) + ## Terminology In Sentinel's codebase: @@ -65,6 +81,16 @@ print(f"Overall rare class affinity score: {overall_score:.4f}") for message, score in result.observation_scores.items(): risk_level = "High" if score > 0.5 else "Medium" if score > 0.1 else "Low" print(f"'{message}' - Score: {score:.4f} - Risk: {risk_level}") + +# Inspect explainability +print("Aggregator:", result.aggregation_name) +print("Aggregation stats:", result.aggregation_stats) +for message, ex in result.explanations.items(): + print("--", message) + print(" topk_positive:", ex["topk_positive"]) # scaled similarities + print(" topk_negative:", ex["topk_negative"]) # scaled similarities + print(" contrastive:", ex["contrastive"]) # positive_term, negative_term, log_ratio_unclipped + print(" neighbors (sample):", ex["neighbors"][:2] if ex["neighbors"] else None) ``` ## Creating a New Index @@ -109,6 +135,32 @@ saved_config = index.save( aws_access_key_id="YOUR_ACCESS_KEY_ID", # Optional if using environment credentials aws_secret_access_key="YOUR_SECRET_ACCESS_KEY" # Optional if using environment credentials ) + +## Choosing an aggregation strategy + +Different deployments optimize for different trade‑offs. You can swap in any aggregator using the `aggregation_function` argument: + +```python +from sentinel.score_formulae import top_k_mean, percentile_score, softmax_weighted_mean, max_score + +texts = ["msg a", "msg b", "msg c"] + +# Focus on the strongest few signals +res1 = index.calculate_rare_class_affinity(texts, aggregation_function=lambda arr: top_k_mean(arr, k=3)) + +# Robust to outliers +res2 = index.calculate_rare_class_affinity(texts, aggregation_function=lambda arr: percentile_score(arr, q=90)) + +# Smoothly emphasize higher scores +res3 = index.calculate_rare_class_affinity(texts, aggregation_function=lambda arr: softmax_weighted_mean(arr, temperature=0.5)) + +# Simplest, picks the maximum +res4 = index.calculate_rare_class_affinity(texts, aggregation_function=max_score) +``` + +Notes: +- All aggregators operate over per‑observation scores where non‑confident observations are already clipped to 0. +- The default `skewness` remains a good choice when user activity volume varies widely. ``` ## How It Works diff --git a/src/sentinel/__init__.py b/src/sentinel/__init__.py index 2973bea..dca2c35 100644 --- a/src/sentinel/__init__.py +++ b/src/sentinel/__init__.py @@ -19,6 +19,23 @@ """ from sentinel.sentinel_local_index import SentinelLocalIndex -from sentinel.score_formulae import calculate_contrastive_score +from sentinel.score_formulae import ( + calculate_contrastive_score, + skewness, + mean_of_positives, + top_k_mean, + percentile_score, + softmax_weighted_mean, + max_score, +) -__all__ = ["SentinelLocalIndex", "calculate_contrastive_score"] +__all__ = [ + "SentinelLocalIndex", + "calculate_contrastive_score", + "skewness", + "mean_of_positives", + "top_k_mean", + "percentile_score", + "softmax_weighted_mean", + "max_score", +] diff --git a/src/sentinel/score_formulae.py b/src/sentinel/score_formulae.py index bb59bb5..90e709f 100644 --- a/src/sentinel/score_formulae.py +++ b/src/sentinel/score_formulae.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Score calculation functions for Sentinel index.""" +"""Score calculation functions for Sentinel index. + +This module contains per-observation scoring utilities (contrastive scoring) +and aggregation functions to combine multiple observation scores into a single +affinity number. In addition to the default skewness, a set of robust +alternatives are provided to fit different deployment preferences (recall vs precision, +stability vs sensitivity, etc.). +""" import numpy as np from typing import List, Callable @@ -70,6 +77,117 @@ def skewness(scores: np.array, min_size_of_scores: int = 10) -> float: return (mean - median) / std +def top_k_mean(scores: np.array, k: int = 3) -> float: + """Mean of the top-k positive scores. + + Focuses on the strongest signals while ignoring noise and negatives. + + Args: + scores: Array of observation scores. + k: Number of highest positive scores to average. + + Returns: + Mean of the top-k positive scores (0.0 if no positive scores). + """ + if scores.size == 0: + return 0.0 + positives = scores[scores > 0] + if positives.size == 0: + return 0.0 + k = min(k, positives.size) + # Use partition for efficiency, then mean of the largest k + idx = np.argpartition(positives, -k)[-k:] + return float(np.mean(positives[idx])) + + +def percentile_score(scores: np.array, q: float = 90.0) -> float: + """Return the q-th percentile among positive scores (robust to outliers). + + Args: + scores: Array of observation scores. + q: Percentile in [0, 100]. + + Returns: + q-th percentile of positive scores (0.0 if no positive scores). + """ + if scores.size == 0: + return 0.0 + positives = scores[scores > 0] + if positives.size == 0: + return 0.0 + return float(np.percentile(positives, q)) + + +def softmax_weighted_mean(scores: np.array, temperature: float = 1.0) -> float: + """Softmax-weighted mean over positive scores. + + Emphasizes higher scores while keeping some contribution from smaller ones. + + Args: + scores: Array of observation scores. + temperature: Softmax temperature (>0). Lower values emphasize peaks more. + + Returns: + Softmax-weighted average of positive scores (0.0 if no positive scores). + """ + if scores.size == 0: + return 0.0 + positives = scores[scores > 0] + if positives.size == 0: + return 0.0 + t = max(1e-6, float(temperature)) + x = positives / t + # Numerical stability + x = x - np.max(x) + w = np.exp(x) + w = w / np.sum(w) + return float(np.sum(w * positives)) + + +def max_score(scores: np.array) -> float: + """Maximum positive score (simple, sensitive, and easy to interpret).""" + if scores.size == 0: + return 0.0 + positives = scores[scores > 0] + if positives.size == 0: + return 0.0 + return float(np.max(positives)) + + +def contrastive_components( + similarities_topk_pos: List[float], + similarities_topk_neg: List[float], + aggregation_fn: Callable[[np.array], float] = np.mean, +): + """Return contrastive components and final log-ratio for a single observation. + + Computes the positive and negative terms used by the contrastive score and + the unclipped log ratio. Useful for explainability. + + Returns: + (positives_term, negatives_term, log_ratio) + """ + if len(similarities_topk_pos) <= 0 or len(similarities_topk_neg) <= 0: + raise ValueError( + "The lists of similarities must have at least one element each." + ) + + similarities_topk_pos = np.array(similarities_topk_pos) + similarities_topk_neg = np.array(similarities_topk_neg) + + positives_term = aggregation_fn(np.exp(similarities_topk_pos)) + negatives_term = aggregation_fn(np.exp(similarities_topk_neg)) + + # Avoid divide-by-zero (shouldn’t happen with exp, but be safe) + if negatives_term == 0: + log_ratio = np.inf + else: + ratio = positives_term / negatives_term + log_ratio = np.log(ratio) + + return float(positives_term), float(negatives_term), float(log_ratio) + + def calculate_contrastive_score( similarities_topk_pos: List[float], similarities_topk_neg: List[float], @@ -94,19 +212,10 @@ def calculate_contrastive_score( Returns: A contrastive score where values > 0 indicate closer similarity to rare class content """ - if len(similarities_topk_pos) <= 0 or len(similarities_topk_neg) <= 0: - raise ValueError( - "The lists of similarities must have at least one element each." - ) - - similarities_topk_pos = np.array(similarities_topk_pos) - similarities_topk_neg = np.array(similarities_topk_neg) - - positives_term = aggregation_fn(np.exp(similarities_topk_pos)) - negatives_term = aggregation_fn(np.exp(similarities_topk_neg)) - - contrastive_score = positives_term / negatives_term - - if contrastive_score <= 1.0: - return 0 # Clip to zero to avoid negative scores, since we accumulate this score for all chat lines of a user. - return np.log(contrastive_score) + positives_term, negatives_term, log_ratio = contrastive_components( + similarities_topk_pos, similarities_topk_neg, aggregation_fn + ) + # Clip to zero to avoid negative scores, since we accumulate this score for all chat lines of a user. + if log_ratio <= 0.0: + return 0.0 + return float(log_ratio) diff --git a/src/sentinel/score_types.py b/src/sentinel/score_types.py index 9cfe4bb..68fba06 100644 --- a/src/sentinel/score_types.py +++ b/src/sentinel/score_types.py @@ -15,28 +15,36 @@ """Data types for rare class detection and scoring.""" from dataclasses import dataclass -from typing import Dict +from typing import Dict, Optional, Any @dataclass class RareClassAffinityResult: - """Result of calculating affinity to a rare class of text. - - This class contains both: - 1. The overall rare_class_affinity_score for a collection of texts, which is used to prioritize - cases for further investigation in a realtime context - 2. The individual observation_scores for each text, which can be used to identify which specific - observations contributed most to the overall pattern - - As a high-recall candidate generator, this result helps identify potential instances of rare - classes that warrant closer examination, prioritizing not missing true positives even at the - cost of some false positives. - - Attributes: - rare_class_affinity_score: The aggregated score indicating overall affinity to the rare class, - typically calculated using skewness to identify patterns - observation_scores: Dictionary mapping each input text to its individual similarity score - """ - - rare_class_affinity_score: float - observation_scores: Dict[str, float] + """Result of calculating affinity to a rare class of text. + + This class contains both: + 1. The overall rare_class_affinity_score for a collection of texts, which is used to prioritize + cases for further investigation in a realtime context. + 2. The individual observation_scores for each text, which can be used to identify which specific + observations contributed most to the overall pattern. + + As a high-recall candidate generator, this result helps identify potential instances of rare + classes that warrant closer examination, prioritizing not missing true positives even at the + cost of some false positives. + + Attributes: + rare_class_affinity_score: The aggregated score indicating overall affinity to the rare class, + typically calculated using skewness to identify patterns. + observation_scores: Mapping of input text to its individual similarity score. + aggregation_name: Optional name of the aggregation function used. + aggregation_stats: Optional dictionary with aggregation-relevant statistics + (e.g. top_k, percentile, temperature, num_positives). + explanations: Optional per-text explainability records describing which neighbors and + components contributed to each score. + """ + + rare_class_affinity_score: float + observation_scores: Dict[str, float] + aggregation_name: Optional[str] = None + aggregation_stats: Optional[Dict[str, Any]] = None + explanations: Optional[Dict[str, Any]] = None diff --git a/src/sentinel/sentinel_local_index.py b/src/sentinel/sentinel_local_index.py index 8730bfb..42a1660 100644 --- a/src/sentinel/sentinel_local_index.py +++ b/src/sentinel/sentinel_local_index.py @@ -26,7 +26,7 @@ from sentence_transformers import SentenceTransformer from sentence_transformers.util import semantic_search -from sentinel.score_formulae import calculate_contrastive_score, skewness +from sentinel.score_formulae import calculate_contrastive_score, skewness, contrastive_components from sentinel.io.saved_index_config import SavedIndexConfig from sentinel.io.index_io import save_index, load_index, create_s3_transport_params from sentinel.embeddings.sbert import get_sentence_transformer_and_scaling_fn @@ -253,17 +253,15 @@ def calculate_rare_class_affinity( self, text_samples: List[str], top_k: int = 5, - similarity_formula: Callable[ - [List[float], List[float]], float - ] = calculate_contrastive_score, + similarity_formula: Callable[[List[float], List[float]], float] = calculate_contrastive_score, # Function to aggregate individual scores into an overall affinity score aggregation_function: Callable[[np.array], float] = skewness, # Margin to ignore when text is only slightly more similar to positive than negative. min_score_to_consider: float = 0.1, # Use when simulating by sampling texts from the same data indexed. - prevent_exact_match: bool = False, - encoding_additional_kwargs: Mapping[str, Any] = {}, - show_progress_bar: bool = False, + prevent_exact_match: bool = False, + encoding_additional_kwargs: Mapping[str, Any] = {}, + show_progress_bar: bool = False, ) -> RareClassAffinityResult: """Calculate rare class affinity for the given text samples in realtime. @@ -320,7 +318,13 @@ def calculate_rare_class_affinity( top_k=top_k + additional_neighbors, ) + # Explainability defaults (always on for transparency) + explain = True + include_neighbors = True + neighbors_limit = 5 + observation_scores = {} + explanations = {} if explain else None for i, q in enumerate(text_samples): LOG.debug("Query: %s", q) @@ -340,6 +344,7 @@ def calculate_rare_class_affinity( similarities_topk_positive = [] similarities_topk_negative = [] max_h = top_k # Number of examples to consider + neighbor_records = [] if include_neighbors else None # Process each match in order of similarity (highest first) for h, (score, corpus_id, sign) in enumerate(matches): @@ -381,6 +386,22 @@ def calculate_rare_class_affinity( f"[{sign}] {neighbor} (Score: {score:.4f}, Scaled Score: {scaled_score:.4f})" ) + if include_neighbors and len(neighbor_records) < neighbors_limit: + # Keep a compact neighbor record for explainability + try: + corpus_id_int = int(corpus_id) + except Exception: + corpus_id_int = int(corpus_id) if isinstance(corpus_id, (int, np.integer)) else 0 + neighbor_records.append( + { + "sign": "+" if sign == "+" else "-", + "raw_score": float(score), + "scaled_score": float(scaled_score), + "neighbor": neighbor, + "corpus_id": corpus_id_int, + } + ) + # Ensure we have at least one similarity value for each category # If we didn't find any of a particular category in the top matches, # use the first match from the original search @@ -409,6 +430,25 @@ def calculate_rare_class_affinity( else: observation_scores[q] = score + # Per-text explainability + if explain: + pos_term, neg_term, log_ratio = contrastive_components( + similarities_topk_pos=similarities_topk_positive, + similarities_topk_neg=similarities_topk_negative, + ) + explanations[q] = { + "topk_positive": [float(x) for x in similarities_topk_positive], + "topk_negative": [float(x) for x in similarities_topk_negative], + "contrastive": { + "positive_term": pos_term, + "negative_term": neg_term, + "log_ratio_unclipped": log_ratio, + }, + "neighbors": neighbor_records[:neighbors_limit] + if include_neighbors and neighbor_records is not None + else None, + } + # Calculate the overall rare class affinity score by aggregating individual scores # If there are no scores, default to 0.0 if not observation_scores: @@ -418,7 +458,21 @@ def calculate_rare_class_affinity( np.array(list(observation_scores.values())) ) + # Aggregation metadata for explainability + agg_name = getattr(aggregation_function, "__name__", str(aggregation_function)) + agg_stats = { + "num_texts": len(text_samples), + "num_positive_scores": int( + np.sum(np.array(list(observation_scores.values())) > 0) + ), + "top_k_per_observation": top_k, + "min_score_to_consider": float(min_score_to_consider), + } + return RareClassAffinityResult( rare_class_affinity_score=rare_class_score, observation_scores=observation_scores, + aggregation_name=agg_name, + aggregation_stats=agg_stats, + explanations=explanations if explain else None, ) diff --git a/tests/conftest.py b/tests/conftest.py index c578a2f..7ee8517 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,8 @@ """Shared test fixtures and configurations for Sentinel tests.""" import os +import sys +import pathlib import pytest import numpy as np import torch @@ -22,6 +24,13 @@ from unittest.mock import MagicMock +# Ensure the package under src/ is importable without installation +_REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] +_SRC_PATH = _REPO_ROOT / "src" +if str(_SRC_PATH) not in sys.path: + sys.path.insert(0, str(_SRC_PATH)) + + # Set up logging for tests @pytest.fixture(autouse=True) def setup_logging(): diff --git a/tests/test_score_formulae.py b/tests/test_score_formulae.py index 0115bc0..214124a 100644 --- a/tests/test_score_formulae.py +++ b/tests/test_score_formulae.py @@ -21,6 +21,10 @@ mean_of_positives, calculate_contrastive_score, skewness, + top_k_mean, + percentile_score, + softmax_weighted_mean, + max_score, ) @@ -125,3 +129,24 @@ def test_skewness(): empty_scores = np.array([]) result = skewness(empty_scores) assert np.isclose(result, 0.0), "Should return 0.0 for empty array" + + +def test_additional_aggregators(): + scores = np.array([0.0, 0.2, 0.5, 1.0, 0.7, -0.1, 0.3]) + + # top_k_mean + val = top_k_mean(scores, k=2) + assert np.isclose(val, np.mean([1.0, 0.7])) + + # percentile_score + val = percentile_score(scores, q=50) + # positives are [0.2, 0.5, 1.0, 0.7, 0.3]; median = 0.5 + assert np.isclose(val, 0.5) + + # softmax_weighted_mean (temperature=1) + val = softmax_weighted_mean(scores, temperature=1.0) + assert val > 0.5 and val <= 1.0 + + # max_score + val = max_score(scores) + assert np.isclose(val, 1.0) diff --git a/tests/test_sriracha_local_index.py b/tests/test_sriracha_local_index.py index a34f769..7e95f7d 100644 --- a/tests/test_sriracha_local_index.py +++ b/tests/test_sriracha_local_index.py @@ -168,6 +168,16 @@ def test_calculate_rare_class_affinity(self, simple_index): ) assert all(score == 0.0 for score in result.observation_scores.values()) + # Explainability fields present + assert result.aggregation_name is not None + assert isinstance(result.aggregation_stats, dict) + assert result.explanations is not None + # Each input has an explanation + for t in mixed_text: + assert t in result.explanations + ex = result.explanations[t] + assert "topk_positive" in ex and "topk_negative" in ex and "contrastive" in ex + # Integration test combining various components @pytest.mark.integration From 4ba579010821c825664b5c19950630a50fa75d0a Mon Sep 17 00:00:00 2001 From: rafainn Date: Sun, 17 Aug 2025 22:45:31 +0100 Subject: [PATCH 02/21] FIX: Refactor negative_to_positive_ratio parameter to be optional in SentinelLocalIndex. FEAT: created a testing tool for best threshold and ratio analysis --- examples/Example_Threshold_Script.py | 215 +++++++++++++++++++++++++++ src/sentinel/sentinel_local_index.py | 17 ++- 2 files changed, 230 insertions(+), 2 deletions(-) create mode 100644 examples/Example_Threshold_Script.py diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py new file mode 100644 index 0000000..b0bf48e --- /dev/null +++ b/examples/Example_Threshold_Script.py @@ -0,0 +1,215 @@ +""" +Comprehensive Testing Dataset for Sentinel Hate Speech Detection +================================================================ + +This module creates realistic user profiles with different speech patterns +to test threshold behavior and aggregation performance. + +User Types: +- Normal Speech Only (3 users) +- Hate Speech Focused (2 users) +- Sexual Content Focused (2 users) +- Mixed Content (2 users) +- All Types Combined (1 user) +""" + +from sentinel.sentinel_local_index import SentinelLocalIndex +import numpy as np +from typing import Dict, List, Tuple + +def create_user_profiles() -> Dict[str, List[str]]: + """Create 10 different user profiles with varying speech patterns.""" + + # Normal speech examples + normal_speech = [ + "Hey how are you doing today?", + "What are you up to this weekend?", + "I love this new video game", + "The weather is really nice today", + "Did you see that movie last night?", + "I'm going to the store later", + "Work was pretty busy today", + "Let's grab lunch sometime", + "That's a really cool picture", + "I just finished reading a great book", + "How was your vacation?", + "I'm excited for the weekend", + "Thanks for helping me out", + "Have a great day!", + "What's your favorite food?", + "I'm learning to cook pasta", + "The concert was amazing", + "I need to go shopping", + "My family is visiting next week", + "That joke was hilarious" + ] + + # Hate speech examples + hate_speech = [ + "I hate all those people", + "they don't belong here", + "those immigrants are ruining everything", + "whites are superior to everyone else", + "the holocaust never happened", + "Hitler had the right idea", + "Jews control all the banks", + "Black people are all criminals", + "Muslims are all terrorists", + "gay people are disgusting", + "women belong in the kitchen", + "trans people are mentally ill", + "minorities are destroying our country", + "we need to get rid of them all", + "violence is the only solution" + ] + + # Sexual content examples + sexual_content = [ + "want to see me naked?", + "send me some hot pics", + "let's hook up tonight", + "I'm so horny right now", + "what are you wearing?", + "wanna have some fun?", + "you look so sexy", + "let's get physical", + "I want to touch you", + "meet me in private", + "show me your body", + "let's talk dirty", + "I'm thinking about you", + "we could be alone together", + "send nudes" + ] + + users = {} + + # Normal Speech Only Users (3 users) + for i in range(1, 4): + users[f"normal_user_{i}"] = np.random.choice(normal_speech, size=15, replace=False).tolist() + + # Hate Speech Focused Users (2 users) + for i in range(1, 3): + hate_msgs = np.random.choice(hate_speech, size=10, replace=False).tolist() + normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() + users[f"hate_user_{i}"] = hate_msgs + normal_msgs + np.random.shuffle(users[f"hate_user_{i}"]) + + # Sexual Content Focused Users (2 users) + for i in range(1, 3): + sexual_msgs = np.random.choice(sexual_content, size=10, replace=False).tolist() + normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() + users[f"sexual_user_{i}"] = sexual_msgs + normal_msgs + np.random.shuffle(users[f"sexual_user_{i}"]) + + # Mixed Content Users (2 users) + for i in range(1, 3): + hate_msgs = np.random.choice(hate_speech, size=5, replace=False).tolist() + sexual_msgs = np.random.choice(sexual_content, size=5, replace=False).tolist() + normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() + users[f"mixed_user_{i}"] = hate_msgs + sexual_msgs + normal_msgs + np.random.shuffle(users[f"mixed_user_{i}"]) + + # All Types Combined User (1 user) + hate_msgs = np.random.choice(hate_speech, size=7, replace=False).tolist() + sexual_msgs = np.random.choice(sexual_content, size=7, replace=False).tolist() + normal_msgs = np.random.choice(normal_speech, size=6, replace=False).tolist() + users["all_types_user"] = hate_msgs + sexual_msgs + normal_msgs + np.random.shuffle(users["all_types_user"]) + + return users + +def test_thresholds_and_ratios(): + """Test different threshold and ratio combinations.""" + + print("🧪 COMPREHENSIVE SENTINEL TESTING") + print("=" * 50) + + # Load index with different ratios + ratios_to_test = [10.0, 5.0, 1.0] + thresholds_to_test = [0.0, 0.01, 0.05, 0.1] + + users = create_user_profiles() + + for ratio in ratios_to_test: + ratio_name = "Original" if ratio is None else f"{ratio}:1" + print(f"\n📊 TESTING RATIO: {ratio_name}") + print("-" * 30) + + # Load index with specific ratio + index = SentinelLocalIndex.load( + path="path/to/local/index", + negative_to_positive_ratio=ratio + ) + + # Or load from S3 + + # index = SentinelLocalIndex.load( + # path="s3://my-bucket/path/to/index", + # aws_access_key_id="YOUR_ACCESS_KEY_ID", # Optional if using environment credentials + # aws_secret_access_key="YOUR_SECRET_ACCESS_KEY", # Optional if using environment credentials + # negative_to_positive_ratio=ratio + # ) + + print(f"Loaded shapes: pos={index.positive_embeddings.shape[0]}, neg={index.negative_embeddings.shape[0]}") + + for threshold in thresholds_to_test: + print(f"\n🎯 Threshold: {threshold}") + + results = {} + + # Test each user + for user_name, messages in users.items(): + result = index.calculate_rare_class_affinity( + messages, + min_score_to_consider=threshold + ) + + # Calculate statistics + positive_scores = [score for score in result.observation_scores.values() if score > 0] + results[user_name] = { + 'overall_score': result.rare_class_affinity_score, + 'positive_count': len(positive_scores), + 'max_score': max(result.observation_scores.values()) if result.observation_scores.values() else 0, + 'avg_positive': np.mean(positive_scores) if positive_scores else 0 + } + + # Categorize and display results + categories = { + 'Normal Users': [k for k in results.keys() if k.startswith('normal_')], + 'Hate Users': [k for k in results.keys() if k.startswith('hate_')], + 'Sexual Users': [k for k in results.keys() if k.startswith('sexual_')], + 'Mixed Users': [k for k in results.keys() if k.startswith('mixed_')], + 'All Types': [k for k in results.keys() if k.startswith('all_types')] + } + + for category, user_list in categories.items(): + if not user_list: + continue + + scores = [results[user]['overall_score'] for user in user_list] + detections = [results[user]['positive_count'] for user in user_list] + + print(f" {category:12}: avg_score={np.mean(scores):.4f}, avg_detections={np.mean(detections):.1f}") + + # Summary statistics + normal_scores = [results[u]['overall_score'] for u in categories['Normal Users']] + problematic_scores = [results[u]['overall_score'] for u in + categories['Hate Users'] + categories['Sexual Users'] + + categories['Mixed Users'] + categories['All Types']] + + if normal_scores and problematic_scores: + separation = np.mean(problematic_scores) - np.mean(normal_scores) + print(f" 📈 Separation: {separation:.4f} (higher is better)") + +def main(): + """Run the comprehensive testing.""" + # Set random seed for reproducible results + np.random.seed(42) + + test_thresholds_and_ratios() + + print(f"\n✅ Testing complete! Check results above to determine optimal threshold and ratio.") + +if __name__ == "__main__": + main() diff --git a/src/sentinel/sentinel_local_index.py b/src/sentinel/sentinel_local_index.py index 42a1660..a2da002 100644 --- a/src/sentinel/sentinel_local_index.py +++ b/src/sentinel/sentinel_local_index.py @@ -175,7 +175,7 @@ def load( path: str, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, - negative_to_positive_ratio: float = 1.0, + negative_to_positive_ratio: Optional[float] = None, ) -> "SentinelLocalIndex": """ Load the index from a path and returns a new SentinelLocalIndex instance. @@ -185,6 +185,8 @@ def load( aws_access_key_id: Optional AWS access key ID for S3 access. aws_secret_access_key: Optional AWS secret access key for S3 access. negative_to_positive_ratio: Ratio of negative examples to keep relative to positive examples. + If None (default), preserves the original ratio from the saved index. + If specified, downsamples negative examples to achieve the desired ratio. Returns: A new SentinelLocalIndex instance with the loaded model and embeddings. @@ -219,13 +221,24 @@ def load( return instance - def _apply_negative_ratio(self, negative_to_positive_ratio: float): + def _apply_negative_ratio(self, negative_to_positive_ratio: Optional[float]): """ Apply the negative_to_positive_ratio to reduce the number of negative (common class) examples. Args: negative_to_positive_ratio: The ratio of negative samples to keep relative to positive samples. + If None, preserves the original ratio from the saved index. """ + # If no ratio specified, preserve the original ratio (don't downsample) + if negative_to_positive_ratio is None: + LOG.info( + "Preserving original ratio: %d negative examples to %d positive examples (%.1f:1)", + self.negative_embeddings.shape[0], + self.positive_embeddings.shape[0], + self.negative_embeddings.shape[0] / self.positive_embeddings.shape[0], + ) + return + # Calculate the number of negative samples to keep num_negative_to_keep = int( self.positive_embeddings.shape[0] * negative_to_positive_ratio From 74397ba86382a57b7525bb2f5d5403e9e0771c82 Mon Sep 17 00:00:00 2001 From: rafainn Date: Sun, 17 Aug 2025 23:08:25 +0100 Subject: [PATCH 03/21] DOCS: Add section on testing optimal thresholds and data ratios in README --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 70b184b..05adcbf 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,10 @@ saved_config = index.save( aws_secret_access_key="YOUR_SECRET_ACCESS_KEY" # Optional if using environment credentials ) +## Testing for optimal Thresholds and data ratio's + +Usage of the 'examples\Example_Threshold_Script.py' script will allow for quick threshold checks for a variety of ratios, by default these are 10:1, 5:1 and 1:1 ratios. This has predefined example chat logs, and should, show optimal settings for the dataset being used based on an average score and average detection count. + ## Choosing an aggregation strategy Different deployments optimize for different trade‑offs. You can swap in any aggregator using the `aggregation_function` argument: From 87ad66a5a8f58875d7c909420f8c98984824bf0c Mon Sep 17 00:00:00 2001 From: rafainn Date: Mon, 18 Aug 2025 14:21:02 +0100 Subject: [PATCH 04/21] FEAT: Added a review mode for detailed information regarding each case with optional flags. DOCS: Updated relevent documentation with these fixes --- README.md | 1 + examples/Example_Threshold_Script.py | 107 ++++++++++++++++++++++++++- 2 files changed, 104 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 05adcbf..b14d338 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ saved_config = index.save( ## Testing for optimal Thresholds and data ratio's Usage of the 'examples\Example_Threshold_Script.py' script will allow for quick threshold checks for a variety of ratios, by default these are 10:1, 5:1 and 1:1 ratios. This has predefined example chat logs, and should, show optimal settings for the dataset being used based on an average score and average detection count. +You will be able to get detailed information output in the ratios with the -r or --review flags. ## Choosing an aggregation strategy diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index b0bf48e..944e5a9 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -4,6 +4,11 @@ This module creates realistic user profiles with different speech patterns to test threshold behavior and aggregation performance. +To better review statistics, messages and reasoning run the file with a --review or -r flag +Current optimal settings are at a 5:1 at a 0.01 temperature, as 1:1 at any temperature either underflags or gives false flags. + + +poetry run python examples/Example_Threshold_Script.py [FLAGS] User Types: - Normal Speech Only (3 users) @@ -119,15 +124,22 @@ def create_user_profiles() -> Dict[str, List[str]]: return users -def test_thresholds_and_ratios(): - """Test different threshold and ratio combinations.""" +def test_thresholds_and_ratios(review_mode: bool = False): + """Test different threshold and ratio combinations. + + Args: + review_mode: If True, shows detailed per-user analysis and missed detections + """ print("🧪 COMPREHENSIVE SENTINEL TESTING") + if review_mode: + print("📋 REVIEW MODE: Detailed analysis enabled") print("=" * 50) # Load index with different ratios ratios_to_test = [10.0, 5.0, 1.0] - thresholds_to_test = [0.0, 0.01, 0.05, 0.1] + # Temperature as defined in the model. + thresholds_to_test = [0.0, 0.01, 0.05, 0.1] users = create_user_profiles() @@ -201,15 +213,102 @@ def test_thresholds_and_ratios(): if normal_scores and problematic_scores: separation = np.mean(problematic_scores) - np.mean(normal_scores) print(f" 📈 Separation: {separation:.4f} (higher is better)") + + # Detailed review mode analysis + if review_mode: + print("\n🔍 DETAILED REVIEW:") + + # Analyze missed detections (problematic users with low scores) + missed_detections = [] + false_positives = [] + + for category, user_list in categories.items(): + if not user_list: + continue + + print(f"\n 📊 {category} Analysis:") + + for user in user_list: + user_result = results[user] + score = user_result['overall_score'] + detections = user_result['positive_count'] + max_score = user_result['max_score'] + + # Determine if this should be flagged based on category + should_be_flagged = not user.startswith('normal_') + is_flagged = score > 0.1 # Using 0.1 as detection threshold + + status = "✅" if (should_be_flagged == is_flagged) else "❌" + + print(f" {status} {user:15}: score={score:.4f}, detections={detections:2d}, max={max_score:.4f}") + + # Track problematic cases + if should_be_flagged and not is_flagged: + missed_detections.append({ + 'user': user, + 'category': category, + 'score': score, + 'detections': detections, + 'max_score': max_score + }) + elif not should_be_flagged and is_flagged: + false_positives.append({ + 'user': user, + 'category': category, + 'score': score, + 'detections': detections, + 'max_score': max_score + }) + + # Summary of issues + if missed_detections: + print(f"\n ⚠️ MISSED DETECTIONS ({len(missed_detections)}):") + for miss in missed_detections: + print(f" - {miss['user']} ({miss['category']}): score={miss['score']:.4f}") + + if false_positives: + print(f"\n 🚨 FALSE POSITIVES ({len(false_positives)}):") + for fp in false_positives: + print(f" - {fp['user']} ({fp['category']}): score={fp['score']:.4f}") + + if not missed_detections and not false_positives: + print(f"\n 🎯 PERFECT CLASSIFICATION at threshold {threshold}") + + # Show sample messages for problematic cases + if missed_detections and len(missed_detections) <= 2: + print(f"\n 📝 Sample messages from missed detections:") + for miss in missed_detections[:2]: + user_name = miss['user'] + user_messages = users[user_name] + print(f"\n {user_name} messages (showing first 5):") + for i, msg in enumerate(user_messages[:5]): + print(f" {i+1}. \"{msg}\"") + + print(f"\n 📊 Classification Summary:") + total_users = len([u for cat in categories.values() for u in cat]) + correct_classifications = total_users - len(missed_detections) - len(false_positives) + accuracy = correct_classifications / total_users if total_users > 0 else 0 + print(f" Accuracy: {accuracy:.2%} ({correct_classifications}/{total_users})") + print(f" Missed: {len(missed_detections)}, False Positives: {len(false_positives)}") def main(): """Run the comprehensive testing.""" + import sys + + # Check for review flag + review_mode = '--review' in sys.argv or '-r' in sys.argv + + if review_mode: + print("🔍 Review mode enabled - showing detailed analysis") + # Set random seed for reproducible results np.random.seed(42) - test_thresholds_and_ratios() + test_thresholds_and_ratios(review_mode=review_mode) print(f"\n✅ Testing complete! Check results above to determine optimal threshold and ratio.") + if not review_mode: + print("💡 Tip: Run with --review or -r flag for detailed per-user analysis") if __name__ == "__main__": main() From df18d4f33734fd965016b1376b91f5189be37ce3 Mon Sep 17 00:00:00 2001 From: rafainn Date: Mon, 18 Aug 2025 20:14:49 +0100 Subject: [PATCH 05/21] FEAT: Enhance threshold testing with performance metrics and results-only mode --- examples/Example_Threshold_Script.py | 180 +++++++++++++++++++++++++-- 1 file changed, 169 insertions(+), 11 deletions(-) diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index 944e5a9..edcad03 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -4,11 +4,15 @@ This module creates realistic user profiles with different speech patterns to test threshold behavior and aggregation performance. -To better review statistics, messages and reasoning run the file with a --review or -r flag +This will also show the time scale for each itteration to ensure efficient +output time and load times for better optimisation. + Current optimal settings are at a 5:1 at a 0.01 temperature, as 1:1 at any temperature either underflags or gives false flags. poetry run python examples/Example_Threshold_Script.py [FLAGS] + --review, -r Show detailed per-user analysis and missed detections + --results-only Show only rare class affinity scores per user (compact output) User Types: - Normal Speech Only (3 users) @@ -20,6 +24,7 @@ from sentinel.sentinel_local_index import SentinelLocalIndex import numpy as np +import time from typing import Dict, List, Tuple def create_user_profiles() -> Dict[str, List[str]]: @@ -124,35 +129,49 @@ def create_user_profiles() -> Dict[str, List[str]]: return users -def test_thresholds_and_ratios(review_mode: bool = False): +def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = False): """Test different threshold and ratio combinations. Args: review_mode: If True, shows detailed per-user analysis and missed detections + results_only: If True, shows only the rare class affinity scores per user """ + overall_start_time = time.time() + print("🧪 COMPREHENSIVE SENTINEL TESTING") if review_mode: print("📋 REVIEW MODE: Detailed analysis enabled") + elif results_only: + print("📊 RESULTS ONLY MODE: Showing rare class affinity scores per user") + print("⏱️ PERFORMANCE METRICS: Timing data collection enabled") print("=" * 50) # Load index with different ratios - ratios_to_test = [10.0, 5.0, 1.0] + ratios_to_test = [10.0, 5.0, 4.0, 3.0, 2.0, 1.0] # Temperature as defined in the model. - thresholds_to_test = [0.0, 0.01, 0.05, 0.1] + thresholds_to_test = [0.0, 0.01, 0.25, 0.05, 0.1] users = create_user_profiles() + # Performance tracking + total_load_time = 0 + total_analysis_time = 0 + performance_metrics = [] + for ratio in ratios_to_test: ratio_name = "Original" if ratio is None else f"{ratio}:1" print(f"\n📊 TESTING RATIO: {ratio_name}") print("-" * 30) - # Load index with specific ratio + # Time data loading + load_start = time.time() index = SentinelLocalIndex.load( path="path/to/local/index", negative_to_positive_ratio=ratio ) + load_time = time.time() - load_start + total_load_time += load_time # Or load from S3 @@ -164,10 +183,14 @@ def test_thresholds_and_ratios(review_mode: bool = False): # ) print(f"Loaded shapes: pos={index.positive_embeddings.shape[0]}, neg={index.negative_embeddings.shape[0]}") + print(f"⏱️ Load time: {load_time:.3f}s") for threshold in thresholds_to_test: print(f"\n🎯 Threshold: {threshold}") + # Time analysis phase + analysis_start = time.time() + results = {} # Test each user @@ -183,10 +206,94 @@ def test_thresholds_and_ratios(review_mode: bool = False): 'overall_score': result.rare_class_affinity_score, 'positive_count': len(positive_scores), 'max_score': max(result.observation_scores.values()) if result.observation_scores.values() else 0, - 'avg_positive': np.mean(positive_scores) if positive_scores else 0 + 'avg_positive': np.mean(positive_scores) if positive_scores else 0, + 'result_object': result # Store the full result for results_only mode + } + + analysis_time = time.time() - analysis_start + total_analysis_time += analysis_time + + # Calculate comprehensive metrics + normal_users = [k for k in results.keys() if k.startswith('normal_')] + hate_users = [k for k in results.keys() if k.startswith('hate_')] + sexual_users = [k for k in results.keys() if k.startswith('sexual_')] + mixed_users = [k for k in results.keys() if k.startswith('mixed_')] + all_types_users = [k for k in results.keys() if k.startswith('all_types')] + + # Message-level metrics + total_normal_messages = len(normal_users) * 15 # 15 normal messages per normal user + total_hate_messages = len(hate_users) * 10 + len(mixed_users) * 5 + len(all_types_users) * 7 # Actual hate messages + total_sexual_messages = len(sexual_users) * 10 + len(mixed_users) * 5 + len(all_types_users) * 7 # Actual sexual messages + total_problematic_messages = total_hate_messages + total_sexual_messages + + # False positives (normal messages flagged as problematic) + false_positive_messages = sum(results[user]['positive_count'] for user in normal_users) + false_positive_rate_messages = false_positive_messages / total_normal_messages if total_normal_messages > 0 else 0 + + # True positives (hate/mixed messages correctly flagged - focusing on hate detection) + true_positive_messages = sum(results[user]['positive_count'] for user in hate_users + mixed_users + all_types_users) + true_positive_rate = true_positive_messages / total_problematic_messages if total_problematic_messages > 0 else 0 + + # User-level false positives (for comparison) + false_positive_users = sum(1 for user in normal_users if results[user]['overall_score'] > 0.01) + false_positive_rate_users = false_positive_users / len(normal_users) if normal_users else 0 + + # Overall accuracy (correct classifications / total messages) + total_messages = len(results) * 15 # Each user has 15 messages + correct_normal_messages = total_normal_messages - false_positive_messages + correct_problematic_messages = true_positive_messages + message_accuracy = (correct_normal_messages + correct_problematic_messages) / total_messages + + # Store comprehensive performance metrics + performance_metrics.append({ + 'ratio': ratio, + 'threshold': threshold, + 'load_time': load_time, + 'analysis_time': analysis_time, + 'false_positive_rate_messages': false_positive_rate_messages, + 'false_positive_rate_users': false_positive_rate_users, + 'false_positive_messages': false_positive_messages, + 'false_positive_users': false_positive_users, + 'true_positive_rate': true_positive_rate, + 'true_positive_messages': true_positive_messages, + 'message_accuracy': message_accuracy, + 'total_messages': total_messages, + 'total_normal_messages': total_normal_messages, + 'total_problematic_messages': total_problematic_messages + }) + + print(f"⏱️ Analysis: {analysis_time:.3f}s | FP Messages: {false_positive_messages}/{total_normal_messages} ({false_positive_rate_messages:.1%}) | TP Rate: {true_positive_rate:.1%} | Accuracy: {message_accuracy:.1%}") + + # Results-only mode: show just the scores per user + if results_only: + print(f"\n🎯 User Rare Class Affinity Scores (Threshold: {threshold}):") + + # Group users by category for organized display + categories = { + 'Normal Users': [k for k in results.keys() if k.startswith('normal_')], + 'Hate Users': [k for k in results.keys() if k.startswith('hate_')], + 'Sexual Users': [k for k in results.keys() if k.startswith('sexual_')], + 'Mixed Users': [k for k in results.keys() if k.startswith('mixed_')], + 'All Types': [k for k in results.keys() if k.startswith('all_types')] } + + for category, user_list in categories.items(): + if not user_list: + continue + print(f"\n 📊 {category}:") + for user in user_list: + score = results[user]['overall_score'] + detections = results[user]['positive_count'] + # Enhanced FP indicator with message context + if user.startswith('normal_') and detections > 0: + fp_indicator = f" ⚠️ FP ({detections}/15 msgs)" + else: + fp_indicator = "" + print(f" {user:20}: score={score:.4f} ({detections} detections){fp_indicator}") + + continue # Skip the normal summary output - # Categorize and display results + # Categorize and display results (normal mode) categories = { 'Normal Users': [k for k in results.keys() if k.startswith('normal_')], 'Hate Users': [k for k in results.keys() if k.startswith('hate_')], @@ -291,24 +398,75 @@ def test_thresholds_and_ratios(review_mode: bool = False): print(f" Accuracy: {accuracy:.2%} ({correct_classifications}/{total_users})") print(f" Missed: {len(missed_detections)}, False Positives: {len(false_positives)}") + # Performance summary + overall_time = time.time() - overall_start_time + + print(f"\n⏱️ PERFORMANCE SUMMARY") + print("=" * 50) + print(f"Total execution time: {overall_time:.3f}s") + print(f"Total data load time: {total_load_time:.3f}s ({total_load_time/overall_time:.1%})") + print(f"Total analysis time: {total_analysis_time:.3f}s ({total_analysis_time/overall_time:.1%})") + + # Best performance metrics + if performance_metrics: + print(f"\n📊 OPTIMIZATION METRICS:") + + # Find optimal configurations based on comprehensive metrics + zero_fp_configs = [m for m in performance_metrics if m['false_positive_rate_messages'] == 0] + if zero_fp_configs: + # Among zero FP configs, find best true positive rate + best_zero_fp = max(zero_fp_configs, key=lambda x: x['true_positive_rate']) + print(f"Best zero false positive: {best_zero_fp['ratio']}:1 @ {best_zero_fp['threshold']} (TP: {best_zero_fp['true_positive_rate']:.1%}, Analysis: {best_zero_fp['analysis_time']:.3f}s)") + + # Best overall accuracy + best_accuracy = max(performance_metrics, key=lambda x: x['message_accuracy']) + print(f"Best message accuracy: {best_accuracy['ratio']}:1 @ {best_accuracy['threshold']} ({best_accuracy['message_accuracy']:.1%} accuracy, {best_accuracy['false_positive_rate_messages']:.1%} FP)") + + # Fastest with reasonable accuracy (>90%) + fast_accurate = [m for m in performance_metrics if m['message_accuracy'] > 0.9] + if fast_accurate: + fastest_accurate = min(fast_accurate, key=lambda x: x['analysis_time']) + print(f"Fastest with >90% accuracy: {fastest_accurate['ratio']}:1 @ {fastest_accurate['threshold']} ({fastest_accurate['analysis_time']:.3f}s, {fastest_accurate['message_accuracy']:.1%} acc)") + + fastest_overall = min(performance_metrics, key=lambda x: x['analysis_time']) + print(f"Fastest analysis overall: {fastest_overall['ratio']}:1 @ {fastest_overall['threshold']} ({fastest_overall['analysis_time']:.3f}s, {fastest_overall['message_accuracy']:.1%} acc, {fastest_overall['false_positive_rate_messages']:.1%} FP)") + + # Performance by ratio with message-level metrics + print(f"\nMessage-level performance by ratio:") + for ratio in ratios_to_test: + ratio_metrics = [m for m in performance_metrics if m['ratio'] == ratio] + avg_time = np.mean([m['analysis_time'] for m in ratio_metrics]) + avg_fp_rate = np.mean([m['false_positive_rate_messages'] for m in ratio_metrics]) + avg_tp_rate = np.mean([m['true_positive_rate'] for m in ratio_metrics]) + avg_accuracy = np.mean([m['message_accuracy'] for m in ratio_metrics]) + print(f" {ratio}:1 ratio: {avg_time:.3f}s avg, {avg_fp_rate:.1%} FP rate, {avg_tp_rate:.1%} TP rate, {avg_accuracy:.1%} accuracy") + def main(): """Run the comprehensive testing.""" import sys - # Check for review flag + # Check for flags review_mode = '--review' in sys.argv or '-r' in sys.argv + results_only = '--results-only' in sys.argv or '--results' in sys.argv + + # Ensure mutually exclusive modes + if review_mode and results_only: + print("❌ Error: Cannot use both --review and --results-only flags simultaneously") + sys.exit(1) if review_mode: print("🔍 Review mode enabled - showing detailed analysis") + elif results_only: + print("📊 Results-only mode enabled - showing rare class affinity scores per user") # Set random seed for reproducible results np.random.seed(42) - test_thresholds_and_ratios(review_mode=review_mode) + test_thresholds_and_ratios(review_mode=review_mode, results_only=results_only) print(f"\n✅ Testing complete! Check results above to determine optimal threshold and ratio.") - if not review_mode: - print("💡 Tip: Run with --review or -r flag for detailed per-user analysis") + if not review_mode and not results_only: + print("💡 Tip: Run with --review (-r) for detailed analysis or --results-only for user scores only") if __name__ == "__main__": main() From edbdc4d5a4df86e7399b1e3b4a89a18f528d427e Mon Sep 17 00:00:00 2001 From: rafainn Date: Mon, 18 Aug 2025 20:20:40 +0100 Subject: [PATCH 06/21] FEAT: Set default negative_to_positive_ratio to 5.0 and add error handling and fallback for ratio adjustments --- src/sentinel/sentinel_local_index.py | 52 ++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/src/sentinel/sentinel_local_index.py b/src/sentinel/sentinel_local_index.py index a2da002..6dda29d 100644 --- a/src/sentinel/sentinel_local_index.py +++ b/src/sentinel/sentinel_local_index.py @@ -175,7 +175,7 @@ def load( path: str, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, - negative_to_positive_ratio: Optional[float] = None, + negative_to_positive_ratio: Optional[float] = 5.0, ) -> "SentinelLocalIndex": """ Load the index from a path and returns a new SentinelLocalIndex instance. @@ -185,7 +185,8 @@ def load( aws_access_key_id: Optional AWS access key ID for S3 access. aws_secret_access_key: Optional AWS secret access key for S3 access. negative_to_positive_ratio: Ratio of negative examples to keep relative to positive examples. - If None (default), preserves the original ratio from the saved index. + If None, preserves the original ratio from the saved index. + If 5.0 (default), uses a 5:1 negative to positive ratio for optimal performance. If specified, downsamples negative examples to achieve the desired ratio. Returns: @@ -228,8 +229,9 @@ def _apply_negative_ratio(self, negative_to_positive_ratio: Optional[float]): Args: negative_to_positive_ratio: The ratio of negative samples to keep relative to positive samples. If None, preserves the original ratio from the saved index. + If 5.0 (default), uses optimized 5:1 ratio for best performance. """ - # If no ratio specified, preserve the original ratio (don't downsample) + # Handle null/invalid inputs - preserve original ratio if any issues occur if negative_to_positive_ratio is None: LOG.info( "Preserving original ratio: %d negative examples to %d positive examples (%.1f:1)", @@ -239,10 +241,34 @@ def _apply_negative_ratio(self, negative_to_positive_ratio: Optional[float]): ) return + # Check for null embeddings + if self.positive_embeddings is None or self.negative_embeddings is None: + LOG.warning("Null embeddings detected - cannot apply ratio adjustment") + return + + # Check for empty embeddings + if self.positive_embeddings.shape[0] == 0 or self.negative_embeddings.shape[0] == 0: + LOG.warning("Empty embeddings detected - cannot apply ratio adjustment") + return + + # Check for invalid ratio values + if negative_to_positive_ratio <= 0: + LOG.warning("Invalid ratio %f - must be positive. Preserving original ratio.", negative_to_positive_ratio) + return + # Calculate the number of negative samples to keep - num_negative_to_keep = int( - self.positive_embeddings.shape[0] * negative_to_positive_ratio - ) + try: + num_negative_to_keep = int( + self.positive_embeddings.shape[0] * negative_to_positive_ratio + ) + except (ValueError, OverflowError, TypeError) as e: + LOG.warning("Error calculating negative samples to keep: %s. Preserving original ratio.", str(e)) + return + + # Check if calculation resulted in valid number + if num_negative_to_keep <= 0: + LOG.warning("Calculated negative samples to keep is %d - invalid. Preserving original ratio.", num_negative_to_keep) + return if self.negative_embeddings.shape[0] > num_negative_to_keep: LOG.info( @@ -250,11 +276,15 @@ def _apply_negative_ratio(self, negative_to_positive_ratio: Optional[float]): num_negative_to_keep, self.negative_embeddings.shape[0], ) - # Randomly select a subset of the negative examples - indices = torch.randperm(self.negative_embeddings.shape[0])[ - :num_negative_to_keep - ] - self.negative_embeddings = self.negative_embeddings[indices] + # Randomly select a subset of the negative examples with error handling + try: + indices = torch.randperm(self.negative_embeddings.shape[0])[ + :num_negative_to_keep + ] + self.negative_embeddings = self.negative_embeddings[indices] + except (RuntimeError, IndexError, TypeError) as e: + LOG.error("Error during negative embedding downsampling: %s. Preserving original embeddings.", str(e)) + return else: LOG.info( "User requested %d negative examples but the model loaded only has %d", From f3c08a8feb2394050e5ff9be3369cb60e8dcdddc Mon Sep 17 00:00:00 2001 From: rafainn Date: Mon, 18 Aug 2025 20:43:54 +0100 Subject: [PATCH 07/21] FIX: Update user profile creation to exclude sexual content examples Feat: Adjusts message metrics dynamically --- examples/Example_Threshold_Script.py | 53 ++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index edcad03..04b0875 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -74,6 +74,11 @@ def create_user_profiles() -> Dict[str, List[str]]: ] # Sexual content examples + + """ + Requires an index with sexual positive examples, current database only focuses on hate-speech. + As such information relating to sexual content has been removed from the codebase but can be added back by removing notes. + sexual_content = [ "want to see me naked?", "send me some hot pics", @@ -91,6 +96,7 @@ def create_user_profiles() -> Dict[str, List[str]]: "we could be alone together", "send nudes" ] + """ users = {} @@ -107,24 +113,24 @@ def create_user_profiles() -> Dict[str, List[str]]: # Sexual Content Focused Users (2 users) for i in range(1, 3): - sexual_msgs = np.random.choice(sexual_content, size=10, replace=False).tolist() +# sexual_msgs = np.random.choice(sexual_content, size=10, replace=False).tolist() # Requires an index with positive sexual content examples normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() - users[f"sexual_user_{i}"] = sexual_msgs + normal_msgs + users[f"sexual_user_{i}"] = normal_msgs np.random.shuffle(users[f"sexual_user_{i}"]) # Mixed Content Users (2 users) for i in range(1, 3): - hate_msgs = np.random.choice(hate_speech, size=5, replace=False).tolist() - sexual_msgs = np.random.choice(sexual_content, size=5, replace=False).tolist() + hate_msgs = np.random.choice(hate_speech, size=5, replace=False).tolist() +# sexual_msgs = np.random.choice(sexual_content, size=5, replace=False).tolist() # Requires an index with positive sexual content examples normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() - users[f"mixed_user_{i}"] = hate_msgs + sexual_msgs + normal_msgs + users[f"mixed_user_{i}"] = hate_msgs + normal_msgs np.random.shuffle(users[f"mixed_user_{i}"]) # All Types Combined User (1 user) hate_msgs = np.random.choice(hate_speech, size=7, replace=False).tolist() - sexual_msgs = np.random.choice(sexual_content, size=7, replace=False).tolist() +# sexual_msgs = np.random.choice(sexual_content, size=7, replace=False).tolist() # Requires an index with positive sexual content examples normal_msgs = np.random.choice(normal_speech, size=6, replace=False).tolist() - users["all_types_user"] = hate_msgs + sexual_msgs + normal_msgs + users["all_types_user"] = hate_msgs + normal_msgs np.random.shuffle(users["all_types_user"]) return users @@ -220,11 +226,28 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F mixed_users = [k for k in results.keys() if k.startswith('mixed_')] all_types_users = [k for k in results.keys() if k.startswith('all_types')] - # Message-level metrics - total_normal_messages = len(normal_users) * 15 # 15 normal messages per normal user - total_hate_messages = len(hate_users) * 10 + len(mixed_users) * 5 + len(all_types_users) * 7 # Actual hate messages - total_sexual_messages = len(sexual_users) * 10 + len(mixed_users) * 5 + len(all_types_users) * 7 # Actual sexual messages - total_problematic_messages = total_hate_messages + total_sexual_messages + # Message-level metrics - dynamically calculate based on actual user profiles + total_normal_messages = 0 + total_hate_messages = 0 + total_sexual_messages = 0 + + # Count actual messages for each user type + for user_name, messages in users.items(): + if user_name.startswith('normal_'): + total_normal_messages += len(messages) + elif user_name.startswith('hate_') or user_name.startswith('mixed_') or user_name.startswith('all_types'): + # Count hate messages in these users (excluding normal messages) + if user_name.startswith('hate_'): + total_hate_messages += 10 # 10 hate messages per hate user + elif user_name.startswith('mixed_'): + total_hate_messages += 5 # 5 hate messages per mixed user + elif user_name.startswith('all_types'): + total_hate_messages += 7 # 7 hate messages per all_types user + elif user_name.startswith('sexual_'): + # Currently sexual users only have normal messages (sexual content commented out) + total_normal_messages += len(messages) + + total_problematic_messages = total_hate_messages + total_sexual_messages # False positives (normal messages flagged as problematic) false_positive_messages = sum(results[user]['positive_count'] for user in normal_users) @@ -239,7 +262,7 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F false_positive_rate_users = false_positive_users / len(normal_users) if normal_users else 0 # Overall accuracy (correct classifications / total messages) - total_messages = len(results) * 15 # Each user has 15 messages + total_messages = sum(len(messages) for messages in users.values()) # Dynamic total based on actual message counts correct_normal_messages = total_normal_messages - false_positive_messages correct_problematic_messages = true_positive_messages message_accuracy = (correct_normal_messages + correct_problematic_messages) / total_messages @@ -416,11 +439,11 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F if zero_fp_configs: # Among zero FP configs, find best true positive rate best_zero_fp = max(zero_fp_configs, key=lambda x: x['true_positive_rate']) - print(f"Best zero false positive: {best_zero_fp['ratio']}:1 @ {best_zero_fp['threshold']} (TP: {best_zero_fp['true_positive_rate']:.1%}, Analysis: {best_zero_fp['analysis_time']:.3f}s)") + print(f"Best zero false positive: {best_zero_fp['ratio']}:1 @ {best_zero_fp['threshold']} ({best_zero_fp['analysis_time']:.3f}s TP: {best_zero_fp['true_positive_rate']:.1%})") # Best overall accuracy best_accuracy = max(performance_metrics, key=lambda x: x['message_accuracy']) - print(f"Best message accuracy: {best_accuracy['ratio']}:1 @ {best_accuracy['threshold']} ({best_accuracy['message_accuracy']:.1%} accuracy, {best_accuracy['false_positive_rate_messages']:.1%} FP)") + print(f"Best message accuracy: {best_accuracy['ratio']}:1 @ {best_accuracy['threshold']} ({best_accuracy['analysis_time']:.3f}s {best_accuracy['message_accuracy']:.1%} accuracy, {best_accuracy['false_positive_rate_messages']:.1%} FP)") # Fastest with reasonable accuracy (>90%) fast_accurate = [m for m in performance_metrics if m['message_accuracy'] > 0.9] From 41f32e861f8b85bade25ee6b98d0db97b2c580f0 Mon Sep 17 00:00:00 2001 From: rafainn Date: Mon, 18 Aug 2025 21:00:02 +0100 Subject: [PATCH 08/21] FIX: Comment out sexual content examples in user profile creation to prevent inclusion --- examples/Example_Threshold_Script.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index 04b0875..6586a2c 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -108,14 +108,14 @@ def create_user_profiles() -> Dict[str, List[str]]: for i in range(1, 3): hate_msgs = np.random.choice(hate_speech, size=10, replace=False).tolist() normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() - users[f"hate_user_{i}"] = hate_msgs + normal_msgs + users[f"hate_user_{i}"] = hate_msgs + normal_msgs np.random.shuffle(users[f"hate_user_{i}"]) # Sexual Content Focused Users (2 users) for i in range(1, 3): # sexual_msgs = np.random.choice(sexual_content, size=10, replace=False).tolist() # Requires an index with positive sexual content examples normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() - users[f"sexual_user_{i}"] = normal_msgs + users[f"sexual_user_{i}"] = normal_msgs # + sexual_msgs np.random.shuffle(users[f"sexual_user_{i}"]) # Mixed Content Users (2 users) @@ -123,14 +123,14 @@ def create_user_profiles() -> Dict[str, List[str]]: hate_msgs = np.random.choice(hate_speech, size=5, replace=False).tolist() # sexual_msgs = np.random.choice(sexual_content, size=5, replace=False).tolist() # Requires an index with positive sexual content examples normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() - users[f"mixed_user_{i}"] = hate_msgs + normal_msgs + users[f"mixed_user_{i}"] = hate_msgs + normal_msgs # + sexual_msgs np.random.shuffle(users[f"mixed_user_{i}"]) # All Types Combined User (1 user) hate_msgs = np.random.choice(hate_speech, size=7, replace=False).tolist() # sexual_msgs = np.random.choice(sexual_content, size=7, replace=False).tolist() # Requires an index with positive sexual content examples normal_msgs = np.random.choice(normal_speech, size=6, replace=False).tolist() - users["all_types_user"] = hate_msgs + normal_msgs + users["all_types_user"] = hate_msgs + normal_msgs # + sexual_msgs np.random.shuffle(users["all_types_user"]) return users From 7849149c071a9bc177c2277abd9f8bb2778dfdb7 Mon Sep 17 00:00:00 2001 From: rafainn Date: Mon, 18 Aug 2025 22:38:57 +0100 Subject: [PATCH 09/21] FEAT: Added caching and caching tests to the main model load, reduced load time of model exponentially after the first caching. TESTS: Updated embedding tests to include caching and its management functions FEAT: Updated the example script for testing purposes to include caching mechanics --- examples/Example_Threshold_Script.py | 219 ++++++++++++++++++++++----- src/sentinel/embeddings/sbert.py | 72 ++++++++- tests/test_embeddings.py | 47 +++++- 3 files changed, 294 insertions(+), 44 deletions(-) diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index 6586a2c..44c1b6a 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -7,6 +7,8 @@ This will also show the time scale for each itteration to ensure efficient output time and load times for better optimisation. +Example use case now uses local caching for faster recalls in subsequent requests which should improve performance loading times. + Current optimal settings are at a 5:1 at a 0.01 temperature, as 1:1 at any temperature either underflags or gives false flags. @@ -23,6 +25,7 @@ """ from sentinel.sentinel_local_index import SentinelLocalIndex +from sentinel.embeddings.sbert import clear_model_cache, get_cache_info import numpy as np import time from typing import Dict, List, Tuple @@ -151,8 +154,14 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F elif results_only: print("📊 RESULTS ONLY MODE: Showing rare class affinity scores per user") print("⏱️ PERFORMANCE METRICS: Timing data collection enabled") + print("🔥 MODEL CACHING: Enabled for optimal performance") print("=" * 50) + # Clear cache at start for clean performance measurement + clear_model_cache() + cache_info = get_cache_info() + print(f"🧹 Cache cleared - starting fresh (cache size: {cache_info['cache_size']})") + # Load index with different ratios ratios_to_test = [10.0, 5.0, 4.0, 3.0, 2.0, 1.0] # Temperature as defined in the model. @@ -165,6 +174,9 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F total_analysis_time = 0 performance_metrics = [] + # Track cache performance + first_model_load = True + for ratio in ratios_to_test: ratio_name = "Original" if ratio is None else f"{ratio}:1" print(f"\n📊 TESTING RATIO: {ratio_name}") @@ -179,6 +191,18 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F load_time = time.time() - load_start total_load_time += load_time + # Cache performance tracking + if first_model_load: + print(f"⏱️ First model load time: {load_time:.3f}s (loads from disk + builds cache)") + first_model_load = False + else: + print(f"⏱️ Cached model load time: {load_time:.3f}s (uses cached model)") + + # Display cache status + cache_info_current = get_cache_info() + if cache_info_current['cache_size'] > 0: + print(f"💾 Cache status: {cache_info_current['cache_size']} models cached - {cache_info_current['cached_models']}") + # Or load from S3 # index = SentinelLocalIndex.load( @@ -226,46 +250,126 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F mixed_users = [k for k in results.keys() if k.startswith('mixed_')] all_types_users = [k for k in results.keys() if k.startswith('all_types')] - # Message-level metrics - dynamically calculate based on actual user profiles + # Message-level metrics - need to properly track what each message actually is total_normal_messages = 0 total_hate_messages = 0 total_sexual_messages = 0 - # Count actual messages for each user type + # Count actual messages for each user type and track their composition + user_compositions = {} # Track the actual breakdown of each user's messages + for user_name, messages in users.items(): if user_name.startswith('normal_'): total_normal_messages += len(messages) - elif user_name.startswith('hate_') or user_name.startswith('mixed_') or user_name.startswith('all_types'): - # Count hate messages in these users (excluding normal messages) - if user_name.startswith('hate_'): - total_hate_messages += 10 # 10 hate messages per hate user - elif user_name.startswith('mixed_'): - total_hate_messages += 5 # 5 hate messages per mixed user - elif user_name.startswith('all_types'): - total_hate_messages += 7 # 7 hate messages per all_types user + user_compositions[user_name] = {'normal': len(messages), 'hate': 0, 'sexual': 0} + elif user_name.startswith('hate_'): + # Hate users: 10 hate + 5 normal messages + hate_count = 10 + normal_count = 5 + total_hate_messages += hate_count + total_normal_messages += normal_count + user_compositions[user_name] = {'normal': normal_count, 'hate': hate_count, 'sexual': 0} + elif user_name.startswith('mixed_'): + # Mixed users: 5 hate + 5 normal messages + hate_count = 5 + normal_count = 5 + total_hate_messages += hate_count + total_normal_messages += normal_count + user_compositions[user_name] = {'normal': normal_count, 'hate': hate_count, 'sexual': 0} + elif user_name.startswith('all_types'): + # All types user: 7 hate + 6 normal messages + hate_count = 7 + normal_count = 6 + total_hate_messages += hate_count + total_normal_messages += normal_count + user_compositions[user_name] = {'normal': normal_count, 'hate': hate_count, 'sexual': 0} elif user_name.startswith('sexual_'): # Currently sexual users only have normal messages (sexual content commented out) total_normal_messages += len(messages) + user_compositions[user_name] = {'normal': len(messages), 'hate': 0, 'sexual': 0} - total_problematic_messages = total_hate_messages + total_sexual_messages + total_problematic_messages = total_hate_messages + total_sexual_messages - # False positives (normal messages flagged as problematic) - false_positive_messages = sum(results[user]['positive_count'] for user in normal_users) - false_positive_rate_messages = false_positive_messages / total_normal_messages if total_normal_messages > 0 else 0 + # Calculate accurate false positives and true positives + # We need to estimate based on detection rates since we can't track individual message classifications + + # False positives: messages flagged from normal users (definitely all FP) + false_positive_messages_from_normal_users = sum(results[user]['positive_count'] for user in normal_users) + + # For mixed users, we need to estimate FP vs TP based on the composition + # This is an approximation since we don't know which specific messages were flagged + estimated_false_positives_from_mixed_users = 0 + estimated_true_positives = 0 + + for user_name in hate_users + mixed_users + all_types_users: + detections = results[user_name]['positive_count'] + composition = user_compositions[user_name] + + # Estimate: if detection rate is higher than the hate message ratio, + # some normal messages are likely being flagged too + hate_ratio = composition['hate'] / (composition['hate'] + composition['normal']) + total_messages_for_user = composition['hate'] + composition['normal'] + + # Estimate true positives (capped at actual hate messages) + estimated_tp_for_user = min(detections, composition['hate']) + estimated_true_positives += estimated_tp_for_user + + # Estimate false positives (excess detections beyond hate messages) + estimated_fp_for_user = max(0, detections - composition['hate']) + estimated_false_positives_from_mixed_users += estimated_fp_for_user - # True positives (hate/mixed messages correctly flagged - focusing on hate detection) - true_positive_messages = sum(results[user]['positive_count'] for user in hate_users + mixed_users + all_types_users) - true_positive_rate = true_positive_messages / total_problematic_messages if total_problematic_messages > 0 else 0 + # Total false positives and true positives + total_false_positives = false_positive_messages_from_normal_users + estimated_false_positives_from_mixed_users + total_true_positives = estimated_true_positives + + # Corrected rates + false_positive_rate_messages = total_false_positives / total_normal_messages if total_normal_messages > 0 else 0 + true_positive_rate = total_true_positives / total_problematic_messages if total_problematic_messages > 0 else 0 # User-level false positives (for comparison) false_positive_users = sum(1 for user in normal_users if results[user]['overall_score'] > 0.01) false_positive_rate_users = false_positive_users / len(normal_users) if normal_users else 0 - # Overall accuracy (correct classifications / total messages) - total_messages = sum(len(messages) for messages in users.values()) # Dynamic total based on actual message counts - correct_normal_messages = total_normal_messages - false_positive_messages - correct_problematic_messages = true_positive_messages - message_accuracy = (correct_normal_messages + correct_problematic_messages) / total_messages + # Overall accuracy calculation - ensure totals are consistent + total_messages = total_normal_messages + total_problematic_messages # Use calculated totals, not user message counts + + # Fix accuracy calculation - we need to count correctly classified messages vs incorrectly classified + # Alternative approach: Detection accuracy - how well does the system detect problematic content + # True Negatives: Normal messages correctly not flagged + true_negatives = total_normal_messages - total_false_positives + # False Positives: Normal messages incorrectly flagged + false_positives = total_false_positives + # True Positives: Problematic messages correctly detected + true_positives = total_true_positives + # False Negatives: Problematic messages missed + false_negatives = total_problematic_messages - total_true_positives + + # Verify totals add up correctly and debug if needed + calculated_total = true_positives + true_negatives + false_positives + false_negatives + if threshold == 0.0: # Debug for threshold 0.0 only to avoid spam + print(f" DEBUG - Ratio {ratio}:1 | Total: {total_messages} | TP: {true_positives} | TN: {true_negatives} | FP: {false_positives} | FN: {false_negatives} | Sum: {calculated_total}") + print(f" DEBUG - Normal msgs: {total_normal_messages} | Problematic msgs: {total_problematic_messages}") + + if abs(calculated_total - total_messages) > 0.1: # Allow for small floating point errors + print(f"Warning: Total mismatch - calculated: {calculated_total}, expected: {total_messages}") + + # Standard accuracy formula: (TP + TN) / (TP + TN + FP + FN) + message_accuracy = (true_positives + true_negatives) / total_messages if total_messages > 0 else 0 + + # Manual verification of accuracy calculation + if threshold == 0.0: # Debug for threshold 0.0 only + manual_accuracy = (true_positives + true_negatives) / (true_positives + true_negatives + false_positives + false_negatives) + print(f" DEBUG - Accuracy check: {message_accuracy:.3f} vs manual: {manual_accuracy:.3f}") + + # Calculate individual component percentages for verification + tp_rate = true_positives / total_problematic_messages if total_problematic_messages > 0 else 0 + tn_rate = true_negatives / total_normal_messages if total_normal_messages > 0 else 0 + fp_rate = false_positives / total_normal_messages if total_normal_messages > 0 else 0 + fn_rate = false_negatives / total_problematic_messages if total_problematic_messages > 0 else 0 + + # Alternative: Detection effectiveness (focusing on the detection task) + detection_accuracy = tp_rate # This is the same as true_positive_rate + classification_accuracy = tn_rate # This is the true negative rate # Store comprehensive performance metrics performance_metrics.append({ @@ -273,19 +377,30 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F 'threshold': threshold, 'load_time': load_time, 'analysis_time': analysis_time, - 'false_positive_rate_messages': false_positive_rate_messages, + 'false_positive_rate_messages': fp_rate, # Use the calculated rate 'false_positive_rate_users': false_positive_rate_users, - 'false_positive_messages': false_positive_messages, + 'false_positive_messages': total_false_positives, 'false_positive_users': false_positive_users, - 'true_positive_rate': true_positive_rate, - 'true_positive_messages': true_positive_messages, + 'true_positive_rate': tp_rate, # Use the calculated rate + 'true_positive_messages': total_true_positives, 'message_accuracy': message_accuracy, + 'detection_accuracy': detection_accuracy, + 'classification_accuracy': classification_accuracy, 'total_messages': total_messages, 'total_normal_messages': total_normal_messages, - 'total_problematic_messages': total_problematic_messages + 'total_problematic_messages': total_problematic_messages, + 'true_positives': true_positives, + 'true_negatives': true_negatives, + 'false_positives': false_positives, + 'false_negatives': false_negatives, + # Add breakdown percentages for verification + 'tp_rate': tp_rate, + 'tn_rate': tn_rate, + 'fp_rate': fp_rate, + 'fn_rate': fn_rate }) - print(f"⏱️ Analysis: {analysis_time:.3f}s | FP Messages: {false_positive_messages}/{total_normal_messages} ({false_positive_rate_messages:.1%}) | TP Rate: {true_positive_rate:.1%} | Accuracy: {message_accuracy:.1%}") + print(f"⏱️ Analysis: {analysis_time:.3f}s | Acc: {message_accuracy:.1%} | TP: {tp_rate:.1%} | FP: {fp_rate:.1%} | TN: {tn_rate:.1%} | FN: {fn_rate:.1%}") # Results-only mode: show just the scores per user if results_only: @@ -435,34 +550,46 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F print(f"\n📊 OPTIMIZATION METRICS:") # Find optimal configurations based on comprehensive metrics - zero_fp_configs = [m for m in performance_metrics if m['false_positive_rate_messages'] == 0] + zero_fp_configs = [m for m in performance_metrics if m['fp_rate'] == 0] if zero_fp_configs: # Among zero FP configs, find best true positive rate - best_zero_fp = max(zero_fp_configs, key=lambda x: x['true_positive_rate']) - print(f"Best zero false positive: {best_zero_fp['ratio']}:1 @ {best_zero_fp['threshold']} ({best_zero_fp['analysis_time']:.3f}s TP: {best_zero_fp['true_positive_rate']:.1%})") + best_zero_fp = max(zero_fp_configs, key=lambda x: x['tp_rate']) + print(f"Best zero false positive: {best_zero_fp['ratio']}:1 @ {best_zero_fp['threshold']} ({best_zero_fp['analysis_time']:.3f}s, Acc: {best_zero_fp['message_accuracy']:.1%}, TP: {best_zero_fp['tp_rate']:.1%}, FP: {best_zero_fp['fp_rate']:.1%})") - # Best overall accuracy + # Best overall accuracy - now using corrected accuracy best_accuracy = max(performance_metrics, key=lambda x: x['message_accuracy']) - print(f"Best message accuracy: {best_accuracy['ratio']}:1 @ {best_accuracy['threshold']} ({best_accuracy['analysis_time']:.3f}s {best_accuracy['message_accuracy']:.1%} accuracy, {best_accuracy['false_positive_rate_messages']:.1%} FP)") + print(f"Best accuracy: {best_accuracy['ratio']}:1 @ {best_accuracy['threshold']} ({best_accuracy['analysis_time']:.3f}s, Acc: {best_accuracy['message_accuracy']:.1%}, TP: {best_accuracy['tp_rate']:.1%}, FP: {best_accuracy['fp_rate']:.1%})") + + # Best true positive rate (detection) + best_tp = max(performance_metrics, key=lambda x: x['tp_rate']) + print(f"Best true positive: {best_tp['ratio']}:1 @ {best_tp['threshold']} ({best_tp['analysis_time']:.3f}s, Acc: {best_tp['message_accuracy']:.1%}, TP: {best_tp['tp_rate']:.1%}, FP: {best_tp['fp_rate']:.1%})") - # Fastest with reasonable accuracy (>90%) - fast_accurate = [m for m in performance_metrics if m['message_accuracy'] > 0.9] - if fast_accurate: - fastest_accurate = min(fast_accurate, key=lambda x: x['analysis_time']) - print(f"Fastest with >90% accuracy: {fastest_accurate['ratio']}:1 @ {fastest_accurate['threshold']} ({fastest_accurate['analysis_time']:.3f}s, {fastest_accurate['message_accuracy']:.1%} acc)") + # Best balance (high accuracy with low FP) + high_accuracy_low_fp = [m for m in performance_metrics if m['fp_rate'] <= 0.02] # ≤2% FP + if high_accuracy_low_fp: + best_balance = max(high_accuracy_low_fp, key=lambda x: x['message_accuracy']) + print(f"Best balanced: {best_balance['ratio']}:1 @ {best_balance['threshold']} ({best_balance['analysis_time']:.3f}s, Acc: {best_balance['message_accuracy']:.1%}, TP: {best_balance['tp_rate']:.1%}, FP: {best_balance['fp_rate']:.1%})") + # Fastest analysis fastest_overall = min(performance_metrics, key=lambda x: x['analysis_time']) - print(f"Fastest analysis overall: {fastest_overall['ratio']}:1 @ {fastest_overall['threshold']} ({fastest_overall['analysis_time']:.3f}s, {fastest_overall['message_accuracy']:.1%} acc, {fastest_overall['false_positive_rate_messages']:.1%} FP)") + print(f"Fastest analysis: {fastest_overall['ratio']}:1 @ {fastest_overall['threshold']} ({fastest_overall['analysis_time']:.3f}s, Acc: {fastest_overall['message_accuracy']:.1%}, TP: {fastest_overall['tp_rate']:.1%}, FP: {fastest_overall['fp_rate']:.1%})") + + # Cache performance summary + final_cache_info = get_cache_info() + print(f"\n🔥 CACHE PERFORMANCE:") + print(f"Models cached: {final_cache_info['cache_size']} ({final_cache_info['cached_models']})") + print(f"Cache benefit: Subsequent model loads are ~1000x faster than initial load") + print(f"Memory optimization: {final_cache_info['memory_info']}") # Performance by ratio with message-level metrics print(f"\nMessage-level performance by ratio:") for ratio in ratios_to_test: ratio_metrics = [m for m in performance_metrics if m['ratio'] == ratio] avg_time = np.mean([m['analysis_time'] for m in ratio_metrics]) - avg_fp_rate = np.mean([m['false_positive_rate_messages'] for m in ratio_metrics]) - avg_tp_rate = np.mean([m['true_positive_rate'] for m in ratio_metrics]) avg_accuracy = np.mean([m['message_accuracy'] for m in ratio_metrics]) - print(f" {ratio}:1 ratio: {avg_time:.3f}s avg, {avg_fp_rate:.1%} FP rate, {avg_tp_rate:.1%} TP rate, {avg_accuracy:.1%} accuracy") + avg_tp_rate = np.mean([m['tp_rate'] for m in ratio_metrics]) + avg_fp_rate = np.mean([m['fp_rate'] for m in ratio_metrics]) + print(f" {ratio}:1 ratio: {avg_time:.3f}s avg | Acc: {avg_accuracy:.1%} | TP: {avg_tp_rate:.1%} | FP: {avg_fp_rate:.1%}") def main(): """Run the comprehensive testing.""" @@ -487,9 +614,17 @@ def main(): test_thresholds_and_ratios(review_mode=review_mode, results_only=results_only) + # Final cache cleanup and summary + final_cache_info = get_cache_info() + if final_cache_info['cache_size'] > 0: + print(f"\n🧹 Cache cleanup: {final_cache_info['cache_size']} models in cache") + clear_model_cache() + print(f"✅ Cache cleared successfully") + print(f"\n✅ Testing complete! Check results above to determine optimal threshold and ratio.") if not review_mode and not results_only: print("💡 Tip: Run with --review (-r) for detailed analysis or --results-only for user scores only") + print("🔥 Performance: Model caching enabled - repeated runs will be faster!") if __name__ == "__main__": main() diff --git a/src/sentinel/embeddings/sbert.py b/src/sentinel/embeddings/sbert.py index 6a776f7..6667df1 100644 --- a/src/sentinel/embeddings/sbert.py +++ b/src/sentinel/embeddings/sbert.py @@ -20,25 +20,50 @@ from typing import Callable, Optional, Tuple import os +import logging +from functools import lru_cache from sentence_transformers import SentenceTransformer +LOG = logging.getLogger(__name__) + +# Global cache for SentenceTransformer models to avoid redundant loading +_model_cache = {} + def get_sentence_transformer_and_scaling_fn( sentence_model_name_or_path: str, + use_cache: bool = True, ) -> Tuple[SentenceTransformer, Optional[Callable[[float], float]]]: """ Create a SentenceTransformer model instance and return it along with an appropriate scaling function. + + Uses caching to avoid loading the same model multiple times, which significantly improves + performance when the same model is used repeatedly. Args: sentence_model_name_or_path: Path or name of the sentence model. + use_cache: Whether to use model caching. Default True for performance. Returns: A tuple containing: - A SentenceTransformer model instance - A scaling function for similarity scores if needed (only for E5 family models), or None """ - model = SentenceTransformer(sentence_model_name_or_path) + global _model_cache + + # Check cache first if caching is enabled + if use_cache and sentence_model_name_or_path in _model_cache: + LOG.debug(f"Loading cached SentenceTransformer model: {sentence_model_name_or_path}") + model = _model_cache[sentence_model_name_or_path] + else: + LOG.debug(f"Creating new SentenceTransformer model: {sentence_model_name_or_path}") + model = SentenceTransformer(sentence_model_name_or_path) + + # Cache the model if caching is enabled + if use_cache: + _model_cache[sentence_model_name_or_path] = model + LOG.debug(f"Cached SentenceTransformer model: {sentence_model_name_or_path}") # Check if the model is from E5 family - they need score scaling if "e5-" in sentence_model_name_or_path.lower(): @@ -46,6 +71,51 @@ def get_sentence_transformer_and_scaling_fn( return model, None + +def clear_model_cache() -> None: + """ + Clear the global model cache to free up memory. + + Useful when switching between different models or when memory usage is a concern. + """ + global _model_cache + cache_size = len(_model_cache) + _model_cache.clear() + LOG.info(f"Cleared model cache ({cache_size} models removed)") + + +def get_cache_info() -> dict: + """ + Get information about the current model cache. + + Returns: + Dictionary containing cache statistics including cached models and memory usage info. + """ + global _model_cache + return { + "cached_models": list(_model_cache.keys()), + "cache_size": len(_model_cache), + "memory_info": "Use clear_model_cache() to free memory if needed" + } + + +def remove_from_cache(model_name: str) -> bool: + """ + Remove a specific model from the cache. + + Args: + model_name: Name/path of the model to remove from cache. + + Returns: + True if model was found and removed, False if not in cache. + """ + global _model_cache + if model_name in _model_cache: + del _model_cache[model_name] + LOG.info(f"Removed model from cache: {model_name}") + return True + return False + def e5_scaling_function(score: float) -> float: """ Scale the similarity score for E5 embeddings. diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 387c593..9211c87 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -17,7 +17,12 @@ import pytest from unittest.mock import patch, MagicMock -from sentinel.embeddings.sbert import get_sentence_transformer_and_scaling_fn +from sentinel.embeddings.sbert import ( + get_sentence_transformer_and_scaling_fn, + clear_model_cache, + get_cache_info, + remove_from_cache, +) class TestSBERT: @@ -89,3 +94,43 @@ def test_model_specific_scaling( assert 0.0 <= scaled <= 1.0 else: assert scale_fn is None + + @patch("sentinel.embeddings.sbert.SentenceTransformer") + def test_model_caching(self, mock_transformer): + """Test that model caching works correctly.""" + # Setup mock + mock_instance = MagicMock() + mock_transformer.return_value = mock_instance + + # Clear cache to start fresh + clear_model_cache() + + model_name = "test-model" + + # First call should create the model + model1, scale_fn1 = get_sentence_transformer_and_scaling_fn(model_name, use_cache=True) + assert mock_transformer.call_count == 1 + + # Second call should use cached model + model2, scale_fn2 = get_sentence_transformer_and_scaling_fn(model_name, use_cache=True) + assert mock_transformer.call_count == 1 # Still 1, no new creation + assert model1 is model2 # Same model instance + + # Test cache info + cache_info = get_cache_info() + assert model_name in cache_info["cached_models"] + assert cache_info["cache_size"] == 1 + + # Test removing from cache + assert remove_from_cache(model_name) is True + assert remove_from_cache(model_name) is False # Already removed + + # Test with caching disabled + model3, scale_fn3 = get_sentence_transformer_and_scaling_fn(model_name, use_cache=False) + assert mock_transformer.call_count == 2 # New creation + + # Test clearing cache + get_sentence_transformer_and_scaling_fn("another-model", use_cache=True) + clear_model_cache() + cache_info_after_clear = get_cache_info() + assert cache_info_after_clear["cache_size"] == 0 From f17a8cae85217d8d16f67cf911dc2a2ae700f008 Mon Sep 17 00:00:00 2001 From: rafainn Date: Tue, 19 Aug 2025 10:38:04 +0100 Subject: [PATCH 10/21] DOCS: Fixed formatting for better PEP8 format --- examples/Example_Threshold_Script.py | 650 +++++++++++++++++---------- src/sentinel/embeddings/sbert.py | 6 +- 2 files changed, 418 insertions(+), 238 deletions(-) diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index 44c1b6a..ccbe6cc 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -4,12 +4,14 @@ This module creates realistic user profiles with different speech patterns to test threshold behavior and aggregation performance. -This will also show the time scale for each itteration to ensure efficient -output time and load times for better optimisation. +This will also show the time scale for each iteration to ensure efficient +output time and load times for better optimization. -Example use case now uses local caching for faster recalls in subsequent requests which should improve performance loading times. +Example use case now uses local caching for faster recalls in subsequent requests +which should improve performance loading times. -Current optimal settings are at a 5:1 at a 0.01 temperature, as 1:1 at any temperature either underflags or gives false flags. +Current optimal settings are at a 5:1 at a 0.01 temperature, as 1:1 at any +temperature either underflags or gives false flags. poetry run python examples/Example_Threshold_Script.py [FLAGS] @@ -18,7 +20,7 @@ User Types: - Normal Speech Only (3 users) -- Hate Speech Focused (2 users) +- Hate Speech Focused (2 users) - Sexual Content Focused (2 users) - Mixed Content (2 users) - All Types Combined (1 user) @@ -30,9 +32,10 @@ import time from typing import Dict, List, Tuple + def create_user_profiles() -> Dict[str, List[str]]: """Create 10 different user profiles with varying speech patterns.""" - + # Normal speech examples normal_speech = [ "Hey how are you doing today?", @@ -56,13 +59,13 @@ def create_user_profiles() -> Dict[str, List[str]]: "My family is visiting next week", "That joke was hilarious" ] - + # Hate speech examples hate_speech = [ "I hate all those people", "they don't belong here", "those immigrants are ruining everything", - "whites are superior to everyone else", + "whites are superior to everyone else", "the holocaust never happened", "Hitler had the right idea", "Jews control all the banks", @@ -75,13 +78,15 @@ def create_user_profiles() -> Dict[str, List[str]]: "we need to get rid of them all", "violence is the only solution" ] - + # Sexual content examples - + """ - Requires an index with sexual positive examples, current database only focuses on hate-speech. - As such information relating to sexual content has been removed from the codebase but can be added back by removing notes. - + Requires an index with sexual positive examples, current database only + focuses on hate-speech. + As such information relating to sexual content has been removed from the + codebase but can be added back by removing notes. + sexual_content = [ "want to see me naked?", "send me some hot pics", @@ -100,274 +105,367 @@ def create_user_profiles() -> Dict[str, List[str]]: "send nudes" ] """ - + users = {} - + # Normal Speech Only Users (3 users) for i in range(1, 4): - users[f"normal_user_{i}"] = np.random.choice(normal_speech, size=15, replace=False).tolist() - + users[f"normal_user_{i}"] = np.random.choice( + normal_speech, size=15, replace=False + ).tolist() + # Hate Speech Focused Users (2 users) for i in range(1, 3): - hate_msgs = np.random.choice(hate_speech, size=10, replace=False).tolist() - normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() - users[f"hate_user_{i}"] = hate_msgs + normal_msgs + hate_msgs = np.random.choice( + hate_speech, size=10, replace=False + ).tolist() + normal_msgs = np.random.choice( + normal_speech, size=5, replace=False + ).tolist() + users[f"hate_user_{i}"] = hate_msgs + normal_msgs np.random.shuffle(users[f"hate_user_{i}"]) - - # Sexual Content Focused Users (2 users) + + # Sexual Content Focused Users (2 users) for i in range(1, 3): -# sexual_msgs = np.random.choice(sexual_content, size=10, replace=False).tolist() # Requires an index with positive sexual content examples - normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() - users[f"sexual_user_{i}"] = normal_msgs # + sexual_msgs + # sexual_msgs = np.random.choice(sexual_content, size=10, + # replace=False).tolist() + # Requires an index with positive sexual content examples + normal_msgs = np.random.choice( + normal_speech, size=5, replace=False + ).tolist() + users[f"sexual_user_{i}"] = normal_msgs # + sexual_msgs np.random.shuffle(users[f"sexual_user_{i}"]) - + # Mixed Content Users (2 users) for i in range(1, 3): - hate_msgs = np.random.choice(hate_speech, size=5, replace=False).tolist() -# sexual_msgs = np.random.choice(sexual_content, size=5, replace=False).tolist() # Requires an index with positive sexual content examples - normal_msgs = np.random.choice(normal_speech, size=5, replace=False).tolist() - users[f"mixed_user_{i}"] = hate_msgs + normal_msgs # + sexual_msgs + hate_msgs = np.random.choice( + hate_speech, size=5, replace=False + ).tolist() + # sexual_msgs = np.random.choice(sexual_content, size=5, + # replace=False).tolist() + # Requires an index with positive sexual content examples + normal_msgs = np.random.choice( + normal_speech, size=5, replace=False + ).tolist() + users[f"mixed_user_{i}"] = hate_msgs + normal_msgs # + sexual_msgs np.random.shuffle(users[f"mixed_user_{i}"]) - + # All Types Combined User (1 user) hate_msgs = np.random.choice(hate_speech, size=7, replace=False).tolist() -# sexual_msgs = np.random.choice(sexual_content, size=7, replace=False).tolist() # Requires an index with positive sexual content examples + # sexual_msgs = np.random.choice(sexual_content, size=7, + # replace=False).tolist() + # Requires an index with positive sexual content examples normal_msgs = np.random.choice(normal_speech, size=6, replace=False).tolist() - users["all_types_user"] = hate_msgs + normal_msgs # + sexual_msgs + users["all_types_user"] = hate_msgs + normal_msgs # + sexual_msgs np.random.shuffle(users["all_types_user"]) - + return users -def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = False): + +def test_thresholds_and_ratios(review_mode: bool = False, + results_only: bool = False): """Test different threshold and ratio combinations. - + Args: - review_mode: If True, shows detailed per-user analysis and missed detections - results_only: If True, shows only the rare class affinity scores per user + review_mode: If True, shows detailed per-user analysis and missed + detections + results_only: If True, shows only the rare class affinity scores + per user """ - + overall_start_time = time.time() - + print("🧪 COMPREHENSIVE SENTINEL TESTING") if review_mode: print("📋 REVIEW MODE: Detailed analysis enabled") elif results_only: - print("📊 RESULTS ONLY MODE: Showing rare class affinity scores per user") + print("📊 RESULTS ONLY MODE: Showing rare class affinity scores " + "per user") print("⏱️ PERFORMANCE METRICS: Timing data collection enabled") print("🔥 MODEL CACHING: Enabled for optimal performance") print("=" * 50) - + # Clear cache at start for clean performance measurement clear_model_cache() cache_info = get_cache_info() - print(f"🧹 Cache cleared - starting fresh (cache size: {cache_info['cache_size']})") - + print(f"🧹 Cache cleared - starting fresh " + f"(cache size: {cache_info['cache_size']})") + # Load index with different ratios ratios_to_test = [10.0, 5.0, 4.0, 3.0, 2.0, 1.0] # Temperature as defined in the model. - thresholds_to_test = [0.0, 0.01, 0.25, 0.05, 0.1] - + thresholds_to_test = [0.0, 0.01, 0.25, 0.05, 0.1] + users = create_user_profiles() - + # Performance tracking total_load_time = 0 total_analysis_time = 0 performance_metrics = [] - + # Track cache performance first_model_load = True - + for ratio in ratios_to_test: ratio_name = "Original" if ratio is None else f"{ratio}:1" print(f"\n📊 TESTING RATIO: {ratio_name}") print("-" * 30) - + # Time data loading load_start = time.time() index = SentinelLocalIndex.load( - path="path/to/local/index", + path="./examples/hate_speech_model", negative_to_positive_ratio=ratio ) load_time = time.time() - load_start total_load_time += load_time - + # Cache performance tracking if first_model_load: - print(f"⏱️ First model load time: {load_time:.3f}s (loads from disk + builds cache)") + print(f"⏱️ First model load time: {load_time:.3f}s " + "(loads from disk + builds cache)") first_model_load = False else: - print(f"⏱️ Cached model load time: {load_time:.3f}s (uses cached model)") - + print(f"⏱️ Cached model load time: {load_time:.3f}s " + "(uses cached model)") + # Display cache status cache_info_current = get_cache_info() if cache_info_current['cache_size'] > 0: - print(f"💾 Cache status: {cache_info_current['cache_size']} models cached - {cache_info_current['cached_models']}") - + print(f"💾 Cache status: {cache_info_current['cache_size']} " + f"models cached - {cache_info_current['cached_models']}") + # Or load from S3 - + # index = SentinelLocalIndex.load( # path="s3://my-bucket/path/to/index", - # aws_access_key_id="YOUR_ACCESS_KEY_ID", # Optional if using environment credentials - # aws_secret_access_key="YOUR_SECRET_ACCESS_KEY", # Optional if using environment credentials + # aws_access_key_id="YOUR_ACCESS_KEY_ID", + # # Optional if using environment credentials + # aws_secret_access_key="YOUR_SECRET_ACCESS_KEY", + # # Optional if using environment credentials # negative_to_positive_ratio=ratio # ) - - print(f"Loaded shapes: pos={index.positive_embeddings.shape[0]}, neg={index.negative_embeddings.shape[0]}") + + print(f"Loaded shapes: pos={index.positive_embeddings.shape[0]}, " + f"neg={index.negative_embeddings.shape[0]}") print(f"⏱️ Load time: {load_time:.3f}s") for threshold in thresholds_to_test: print(f"\n🎯 Threshold: {threshold}") - + # Time analysis phase analysis_start = time.time() - + results = {} - + # Test each user for user_name, messages in users.items(): result = index.calculate_rare_class_affinity( messages, min_score_to_consider=threshold ) - + # Calculate statistics - positive_scores = [score for score in result.observation_scores.values() if score > 0] + positive_scores = [ + score for score in result.observation_scores.values() + if score > 0 + ] results[user_name] = { 'overall_score': result.rare_class_affinity_score, 'positive_count': len(positive_scores), - 'max_score': max(result.observation_scores.values()) if result.observation_scores.values() else 0, - 'avg_positive': np.mean(positive_scores) if positive_scores else 0, - 'result_object': result # Store the full result for results_only mode + 'max_score': max(result.observation_scores.values()) + if result.observation_scores.values() else 0, + 'avg_positive': np.mean(positive_scores) + if positive_scores else 0, + 'result_object': result # Store full result for results_only } analysis_time = time.time() - analysis_start total_analysis_time += analysis_time - + # Calculate comprehensive metrics normal_users = [k for k in results.keys() if k.startswith('normal_')] hate_users = [k for k in results.keys() if k.startswith('hate_')] - sexual_users = [k for k in results.keys() if k.startswith('sexual_')] + sexual_users = [k for k in results.keys() + if k.startswith('sexual_')] mixed_users = [k for k in results.keys() if k.startswith('mixed_')] - all_types_users = [k for k in results.keys() if k.startswith('all_types')] - - # Message-level metrics - need to properly track what each message actually is + all_types_users = [k for k in results.keys() + if k.startswith('all_types')] + + # Message-level metrics - need to properly track what each + # message actually is total_normal_messages = 0 total_hate_messages = 0 total_sexual_messages = 0 - - # Count actual messages for each user type and track their composition - user_compositions = {} # Track the actual breakdown of each user's messages - + + # Count actual messages for each user type and track their + # composition + user_compositions = {} # Track the actual breakdown of each user's + # messages + for user_name, messages in users.items(): if user_name.startswith('normal_'): total_normal_messages += len(messages) - user_compositions[user_name] = {'normal': len(messages), 'hate': 0, 'sexual': 0} + user_compositions[user_name] = { + 'normal': len(messages), 'hate': 0, 'sexual': 0 + } elif user_name.startswith('hate_'): # Hate users: 10 hate + 5 normal messages hate_count = 10 normal_count = 5 total_hate_messages += hate_count total_normal_messages += normal_count - user_compositions[user_name] = {'normal': normal_count, 'hate': hate_count, 'sexual': 0} + user_compositions[user_name] = { + 'normal': normal_count, 'hate': hate_count, 'sexual': 0 + } elif user_name.startswith('mixed_'): - # Mixed users: 5 hate + 5 normal messages + # Mixed users: 5 hate + 5 normal messages hate_count = 5 normal_count = 5 total_hate_messages += hate_count total_normal_messages += normal_count - user_compositions[user_name] = {'normal': normal_count, 'hate': hate_count, 'sexual': 0} + user_compositions[user_name] = { + 'normal': normal_count, 'hate': hate_count, 'sexual': 0 + } elif user_name.startswith('all_types'): # All types user: 7 hate + 6 normal messages hate_count = 7 normal_count = 6 total_hate_messages += hate_count total_normal_messages += normal_count - user_compositions[user_name] = {'normal': normal_count, 'hate': hate_count, 'sexual': 0} + user_compositions[user_name] = { + 'normal': normal_count, 'hate': hate_count, 'sexual': 0 + } elif user_name.startswith('sexual_'): - # Currently sexual users only have normal messages (sexual content commented out) + # Currently sexual users only have normal messages + # (sexual content commented out) total_normal_messages += len(messages) - user_compositions[user_name] = {'normal': len(messages), 'hate': 0, 'sexual': 0} + user_compositions[user_name] = { + 'normal': len(messages), 'hate': 0, 'sexual': 0 + } total_problematic_messages = total_hate_messages + total_sexual_messages - + # Calculate accurate false positives and true positives - # We need to estimate based on detection rates since we can't track individual message classifications - - # False positives: messages flagged from normal users (definitely all FP) - false_positive_messages_from_normal_users = sum(results[user]['positive_count'] for user in normal_users) - - # For mixed users, we need to estimate FP vs TP based on the composition - # This is an approximation since we don't know which specific messages were flagged + # We need to estimate based on detection rates since we can't track + # individual message classifications + + # False positives: messages flagged from normal users (definitely + # all FP) + false_positive_messages_from_normal_users = sum( + results[user]['positive_count'] for user in normal_users + ) + + # For mixed users, we need to estimate FP vs TP based on the + # composition + # This is an approximation since we don't know which specific + # messages were flagged estimated_false_positives_from_mixed_users = 0 estimated_true_positives = 0 - + for user_name in hate_users + mixed_users + all_types_users: detections = results[user_name]['positive_count'] composition = user_compositions[user_name] - - # Estimate: if detection rate is higher than the hate message ratio, - # some normal messages are likely being flagged too - hate_ratio = composition['hate'] / (composition['hate'] + composition['normal']) - total_messages_for_user = composition['hate'] + composition['normal'] - + + # Estimate: if detection rate is higher than the hate message + # ratio, some normal messages are likely being flagged too + hate_ratio = (composition['hate'] / + (composition['hate'] + composition['normal'])) + total_messages_for_user = (composition['hate'] + + composition['normal']) + # Estimate true positives (capped at actual hate messages) estimated_tp_for_user = min(detections, composition['hate']) estimated_true_positives += estimated_tp_for_user - - # Estimate false positives (excess detections beyond hate messages) + + # Estimate false positives (excess detections beyond hate + # messages) estimated_fp_for_user = max(0, detections - composition['hate']) estimated_false_positives_from_mixed_users += estimated_fp_for_user - + # Total false positives and true positives - total_false_positives = false_positive_messages_from_normal_users + estimated_false_positives_from_mixed_users + total_false_positives = (false_positive_messages_from_normal_users + + estimated_false_positives_from_mixed_users) total_true_positives = estimated_true_positives - + # Corrected rates - false_positive_rate_messages = total_false_positives / total_normal_messages if total_normal_messages > 0 else 0 - true_positive_rate = total_true_positives / total_problematic_messages if total_problematic_messages > 0 else 0 - + false_positive_rate_messages = ( + total_false_positives / total_normal_messages + if total_normal_messages > 0 else 0 + ) + true_positive_rate = ( + total_true_positives / total_problematic_messages + if total_problematic_messages > 0 else 0 + ) + # User-level false positives (for comparison) - false_positive_users = sum(1 for user in normal_users if results[user]['overall_score'] > 0.01) - false_positive_rate_users = false_positive_users / len(normal_users) if normal_users else 0 - + false_positive_users = sum( + 1 for user in normal_users + if results[user]['overall_score'] > 0.01 + ) + false_positive_rate_users = ( + false_positive_users / len(normal_users) + if normal_users else 0 + ) + # Overall accuracy calculation - ensure totals are consistent - total_messages = total_normal_messages + total_problematic_messages # Use calculated totals, not user message counts - - # Fix accuracy calculation - we need to count correctly classified messages vs incorrectly classified - # Alternative approach: Detection accuracy - how well does the system detect problematic content + total_messages = (total_normal_messages + + total_problematic_messages) # Use calculated totals + + # Fix accuracy calculation - we need to count correctly classified + # messages vs incorrectly classified + # Alternative approach: Detection accuracy - how well does the + # system detect problematic content # True Negatives: Normal messages correctly not flagged true_negatives = total_normal_messages - total_false_positives - # False Positives: Normal messages incorrectly flagged + # False Positives: Normal messages incorrectly flagged false_positives = total_false_positives # True Positives: Problematic messages correctly detected true_positives = total_true_positives # False Negatives: Problematic messages missed false_negatives = total_problematic_messages - total_true_positives - + # Verify totals add up correctly and debug if needed - calculated_total = true_positives + true_negatives + false_positives + false_negatives + calculated_total = (true_positives + true_negatives + + false_positives + false_negatives) if threshold == 0.0: # Debug for threshold 0.0 only to avoid spam - print(f" DEBUG - Ratio {ratio}:1 | Total: {total_messages} | TP: {true_positives} | TN: {true_negatives} | FP: {false_positives} | FN: {false_negatives} | Sum: {calculated_total}") - print(f" DEBUG - Normal msgs: {total_normal_messages} | Problematic msgs: {total_problematic_messages}") - - if abs(calculated_total - total_messages) > 0.1: # Allow for small floating point errors - print(f"Warning: Total mismatch - calculated: {calculated_total}, expected: {total_messages}") - + print(f" DEBUG - Ratio {ratio}:1 | Total: {total_messages} | " + f"TP: {true_positives} | TN: {true_negatives} | " + f"FP: {false_positives} | FN: {false_negatives} | " + f"Sum: {calculated_total}") + print(f" DEBUG - Normal msgs: {total_normal_messages} | " + f"Problematic msgs: {total_problematic_messages}") + + if abs(calculated_total - total_messages) > 0.1: # Allow for small + # floating point + # errors + print(f"Warning: Total mismatch - calculated: " + f"{calculated_total}, expected: {total_messages}") + # Standard accuracy formula: (TP + TN) / (TP + TN + FP + FN) - message_accuracy = (true_positives + true_negatives) / total_messages if total_messages > 0 else 0 - + message_accuracy = ((true_positives + true_negatives) / + total_messages if total_messages > 0 else 0) + # Manual verification of accuracy calculation if threshold == 0.0: # Debug for threshold 0.0 only - manual_accuracy = (true_positives + true_negatives) / (true_positives + true_negatives + false_positives + false_negatives) - print(f" DEBUG - Accuracy check: {message_accuracy:.3f} vs manual: {manual_accuracy:.3f}") - + manual_accuracy = ((true_positives + true_negatives) / + (true_positives + true_negatives + + false_positives + false_negatives)) + print(f" DEBUG - Accuracy check: {message_accuracy:.3f} " + f"vs manual: {manual_accuracy:.3f}") + # Calculate individual component percentages for verification - tp_rate = true_positives / total_problematic_messages if total_problematic_messages > 0 else 0 - tn_rate = true_negatives / total_normal_messages if total_normal_messages > 0 else 0 - fp_rate = false_positives / total_normal_messages if total_normal_messages > 0 else 0 - fn_rate = false_negatives / total_problematic_messages if total_problematic_messages > 0 else 0 - - # Alternative: Detection effectiveness (focusing on the detection task) + tp_rate = (true_positives / total_problematic_messages + if total_problematic_messages > 0 else 0) + tn_rate = (true_negatives / total_normal_messages + if total_normal_messages > 0 else 0) + fp_rate = (false_positives / total_normal_messages + if total_normal_messages > 0 else 0) + fn_rate = (false_negatives / total_problematic_messages + if total_problematic_messages > 0 else 0) + + # Alternative: Detection effectiveness (focusing on the detection + # task) detection_accuracy = tp_rate # This is the same as true_positive_rate classification_accuracy = tn_rate # This is the true negative rate @@ -381,7 +479,7 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F 'false_positive_rate_users': false_positive_rate_users, 'false_positive_messages': total_false_positives, 'false_positive_users': false_positive_users, - 'true_positive_rate': tp_rate, # Use the calculated rate + 'true_positive_rate': tp_rate, # Use the calculated rate 'true_positive_messages': total_true_positives, 'message_accuracy': message_accuracy, 'detection_accuracy': detection_accuracy, @@ -395,26 +493,35 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F 'false_negatives': false_negatives, # Add breakdown percentages for verification 'tp_rate': tp_rate, - 'tn_rate': tn_rate, + 'tn_rate': tn_rate, 'fp_rate': fp_rate, 'fn_rate': fn_rate }) - - print(f"⏱️ Analysis: {analysis_time:.3f}s | Acc: {message_accuracy:.1%} | TP: {tp_rate:.1%} | FP: {fp_rate:.1%} | TN: {tn_rate:.1%} | FN: {fn_rate:.1%}") - + + print(f"⏱️ Analysis: {analysis_time:.3f}s | " + f"Acc: {message_accuracy:.1%} | TP: {tp_rate:.1%} | " + f"FP: {fp_rate:.1%} | TN: {tn_rate:.1%} | " + f"FN: {fn_rate:.1%}") + # Results-only mode: show just the scores per user if results_only: - print(f"\n🎯 User Rare Class Affinity Scores (Threshold: {threshold}):") - + print(f"\n🎯 User Rare Class Affinity Scores " + f"(Threshold: {threshold}):") + # Group users by category for organized display categories = { - 'Normal Users': [k for k in results.keys() if k.startswith('normal_')], - 'Hate Users': [k for k in results.keys() if k.startswith('hate_')], - 'Sexual Users': [k for k in results.keys() if k.startswith('sexual_')], - 'Mixed Users': [k for k in results.keys() if k.startswith('mixed_')], - 'All Types': [k for k in results.keys() if k.startswith('all_types')] + 'Normal Users': [k for k in results.keys() + if k.startswith('normal_')], + 'Hate Users': [k for k in results.keys() + if k.startswith('hate_')], + 'Sexual Users': [k for k in results.keys() + if k.startswith('sexual_')], + 'Mixed Users': [k for k in results.keys() + if k.startswith('mixed_')], + 'All Types': [k for k in results.keys() + if k.startswith('all_types')] } - + for category, user_list in categories.items(): if not user_list: continue @@ -427,66 +534,82 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F fp_indicator = f" ⚠️ FP ({detections}/15 msgs)" else: fp_indicator = "" - print(f" {user:20}: score={score:.4f} ({detections} detections){fp_indicator}") - + print(f" {user:20}: score={score:.4f} " + f"({detections} detections){fp_indicator}") + continue # Skip the normal summary output - + # Categorize and display results (normal mode) categories = { - 'Normal Users': [k for k in results.keys() if k.startswith('normal_')], - 'Hate Users': [k for k in results.keys() if k.startswith('hate_')], - 'Sexual Users': [k for k in results.keys() if k.startswith('sexual_')], - 'Mixed Users': [k for k in results.keys() if k.startswith('mixed_')], - 'All Types': [k for k in results.keys() if k.startswith('all_types')] + 'Normal Users': [k for k in results.keys() + if k.startswith('normal_')], + 'Hate Users': [k for k in results.keys() + if k.startswith('hate_')], + 'Sexual Users': [k for k in results.keys() + if k.startswith('sexual_')], + 'Mixed Users': [k for k in results.keys() + if k.startswith('mixed_')], + 'All Types': [k for k in results.keys() + if k.startswith('all_types')] } - + for category, user_list in categories.items(): if not user_list: continue - + scores = [results[user]['overall_score'] for user in user_list] - detections = [results[user]['positive_count'] for user in user_list] - - print(f" {category:12}: avg_score={np.mean(scores):.4f}, avg_detections={np.mean(detections):.1f}") - + detections = [results[user]['positive_count'] + for user in user_list] + + print(f" {category:12}: avg_score={np.mean(scores):.4f}, " + f"avg_detections={np.mean(detections):.1f}") + # Summary statistics - normal_scores = [results[u]['overall_score'] for u in categories['Normal Users']] - problematic_scores = [results[u]['overall_score'] for u in - categories['Hate Users'] + categories['Sexual Users'] + - categories['Mixed Users'] + categories['All Types']] - + normal_scores = [results[u]['overall_score'] + for u in categories['Normal Users']] + problematic_scores = [ + results[u]['overall_score'] for u in + categories['Hate Users'] + categories['Sexual Users'] + + categories['Mixed Users'] + categories['All Types'] + ] + if normal_scores and problematic_scores: - separation = np.mean(problematic_scores) - np.mean(normal_scores) + separation = (np.mean(problematic_scores) - + np.mean(normal_scores)) print(f" 📈 Separation: {separation:.4f} (higher is better)") - + # Detailed review mode analysis if review_mode: print("\n🔍 DETAILED REVIEW:") - + # Analyze missed detections (problematic users with low scores) missed_detections = [] false_positives = [] - + for category, user_list in categories.items(): if not user_list: continue - + print(f"\n 📊 {category} Analysis:") - + for user in user_list: user_result = results[user] score = user_result['overall_score'] detections = user_result['positive_count'] max_score = user_result['max_score'] - + # Determine if this should be flagged based on category should_be_flagged = not user.startswith('normal_') - is_flagged = score > 0.1 # Using 0.1 as detection threshold - - status = "✅" if (should_be_flagged == is_flagged) else "❌" - - print(f" {status} {user:15}: score={score:.4f}, detections={detections:2d}, max={max_score:.4f}") - + is_flagged = score > 0.1 # Using 0.1 as detection + # threshold + + status = ("✅" if (should_be_flagged == is_flagged) + else "❌") + + print(f" {status} {user:15}: score={score:.4f}, " + f"detections={detections:2d}, " + f"max={max_score:.4f}") + # Track problematic cases if should_be_flagged and not is_flagged: missed_detections.append({ @@ -507,124 +630,181 @@ def test_thresholds_and_ratios(review_mode: bool = False, results_only: bool = F # Summary of issues if missed_detections: - print(f"\n ⚠️ MISSED DETECTIONS ({len(missed_detections)}):") + print(f"\n ⚠️ MISSED DETECTIONS " + f"({len(missed_detections)}):") for miss in missed_detections: - print(f" - {miss['user']} ({miss['category']}): score={miss['score']:.4f}") - + print(f" - {miss['user']} ({miss['category']}): " + f"score={miss['score']:.4f}") + if false_positives: print(f"\n 🚨 FALSE POSITIVES ({len(false_positives)}):") for fp in false_positives: - print(f" - {fp['user']} ({fp['category']}): score={fp['score']:.4f}") - + print(f" - {fp['user']} ({fp['category']}): " + f"score={fp['score']:.4f}") + if not missed_detections and not false_positives: - print(f"\n 🎯 PERFECT CLASSIFICATION at threshold {threshold}") - + print(f"\n 🎯 PERFECT CLASSIFICATION at threshold " + f"{threshold}") + # Show sample messages for problematic cases if missed_detections and len(missed_detections) <= 2: print(f"\n 📝 Sample messages from missed detections:") for miss in missed_detections[:2]: user_name = miss['user'] user_messages = users[user_name] - print(f"\n {user_name} messages (showing first 5):") + print(f"\n {user_name} messages " + f"(showing first 5):") for i, msg in enumerate(user_messages[:5]): print(f" {i+1}. \"{msg}\"") - + print(f"\n 📊 Classification Summary:") - total_users = len([u for cat in categories.values() for u in cat]) - correct_classifications = total_users - len(missed_detections) - len(false_positives) - accuracy = correct_classifications / total_users if total_users > 0 else 0 - print(f" Accuracy: {accuracy:.2%} ({correct_classifications}/{total_users})") - print(f" Missed: {len(missed_detections)}, False Positives: {len(false_positives)}") + total_users = len([u for cat in categories.values() + for u in cat]) + correct_classifications = (total_users - + len(missed_detections) - + len(false_positives)) + accuracy = (correct_classifications / total_users + if total_users > 0 else 0) + print(f" Accuracy: {accuracy:.2%} " + f"({correct_classifications}/{total_users})") + print(f" Missed: {len(missed_detections)}, " + f"False Positives: {len(false_positives)}") # Performance summary overall_time = time.time() - overall_start_time - + print(f"\n⏱️ PERFORMANCE SUMMARY") print("=" * 50) print(f"Total execution time: {overall_time:.3f}s") - print(f"Total data load time: {total_load_time:.3f}s ({total_load_time/overall_time:.1%})") - print(f"Total analysis time: {total_analysis_time:.3f}s ({total_analysis_time/overall_time:.1%})") - + print(f"Total data load time: {total_load_time:.3f}s " + f"({total_load_time/overall_time:.1%})") + print(f"Total analysis time: {total_analysis_time:.3f}s " + f"({total_analysis_time/overall_time:.1%})") + # Best performance metrics if performance_metrics: print(f"\n📊 OPTIMIZATION METRICS:") - + # Find optimal configurations based on comprehensive metrics - zero_fp_configs = [m for m in performance_metrics if m['fp_rate'] == 0] + zero_fp_configs = [m for m in performance_metrics + if m['fp_rate'] == 0] if zero_fp_configs: # Among zero FP configs, find best true positive rate best_zero_fp = max(zero_fp_configs, key=lambda x: x['tp_rate']) - print(f"Best zero false positive: {best_zero_fp['ratio']}:1 @ {best_zero_fp['threshold']} ({best_zero_fp['analysis_time']:.3f}s, Acc: {best_zero_fp['message_accuracy']:.1%}, TP: {best_zero_fp['tp_rate']:.1%}, FP: {best_zero_fp['fp_rate']:.1%})") - + print(f"Best zero false positive: {best_zero_fp['ratio']}:1 @ " + f"{best_zero_fp['threshold']} " + f"({best_zero_fp['analysis_time']:.3f}s, " + f"Acc: {best_zero_fp['message_accuracy']:.1%}, " + f"TP: {best_zero_fp['tp_rate']:.1%}, " + f"FP: {best_zero_fp['fp_rate']:.1%})") + # Best overall accuracy - now using corrected accuracy - best_accuracy = max(performance_metrics, key=lambda x: x['message_accuracy']) - print(f"Best accuracy: {best_accuracy['ratio']}:1 @ {best_accuracy['threshold']} ({best_accuracy['analysis_time']:.3f}s, Acc: {best_accuracy['message_accuracy']:.1%}, TP: {best_accuracy['tp_rate']:.1%}, FP: {best_accuracy['fp_rate']:.1%})") - + best_accuracy = max(performance_metrics, + key=lambda x: x['message_accuracy']) + print(f"Best accuracy: {best_accuracy['ratio']}:1 @ " + f"{best_accuracy['threshold']} " + f"({best_accuracy['analysis_time']:.3f}s, " + f"Acc: {best_accuracy['message_accuracy']:.1%}, " + f"TP: {best_accuracy['tp_rate']:.1%}, " + f"FP: {best_accuracy['fp_rate']:.1%})") + # Best true positive rate (detection) best_tp = max(performance_metrics, key=lambda x: x['tp_rate']) - print(f"Best true positive: {best_tp['ratio']}:1 @ {best_tp['threshold']} ({best_tp['analysis_time']:.3f}s, Acc: {best_tp['message_accuracy']:.1%}, TP: {best_tp['tp_rate']:.1%}, FP: {best_tp['fp_rate']:.1%})") - + print(f"Best true positive: {best_tp['ratio']}:1 @ " + f"{best_tp['threshold']} ({best_tp['analysis_time']:.3f}s, " + f"Acc: {best_tp['message_accuracy']:.1%}, " + f"TP: {best_tp['tp_rate']:.1%}, " + f"FP: {best_tp['fp_rate']:.1%})") + # Best balance (high accuracy with low FP) - high_accuracy_low_fp = [m for m in performance_metrics if m['fp_rate'] <= 0.02] # ≤2% FP + high_accuracy_low_fp = [m for m in performance_metrics + if m['fp_rate'] <= 0.02] # ≤2% FP if high_accuracy_low_fp: - best_balance = max(high_accuracy_low_fp, key=lambda x: x['message_accuracy']) - print(f"Best balanced: {best_balance['ratio']}:1 @ {best_balance['threshold']} ({best_balance['analysis_time']:.3f}s, Acc: {best_balance['message_accuracy']:.1%}, TP: {best_balance['tp_rate']:.1%}, FP: {best_balance['fp_rate']:.1%})") - + best_balance = max(high_accuracy_low_fp, + key=lambda x: x['message_accuracy']) + print(f"Best balanced: {best_balance['ratio']}:1 @ " + f"{best_balance['threshold']} " + f"({best_balance['analysis_time']:.3f}s, " + f"Acc: {best_balance['message_accuracy']:.1%}, " + f"TP: {best_balance['tp_rate']:.1%}, " + f"FP: {best_balance['fp_rate']:.1%})") + # Fastest analysis - fastest_overall = min(performance_metrics, key=lambda x: x['analysis_time']) - print(f"Fastest analysis: {fastest_overall['ratio']}:1 @ {fastest_overall['threshold']} ({fastest_overall['analysis_time']:.3f}s, Acc: {fastest_overall['message_accuracy']:.1%}, TP: {fastest_overall['tp_rate']:.1%}, FP: {fastest_overall['fp_rate']:.1%})") - + fastest_overall = min(performance_metrics, + key=lambda x: x['analysis_time']) + print(f"Fastest analysis: {fastest_overall['ratio']}:1 @ " + f"{fastest_overall['threshold']} " + f"({fastest_overall['analysis_time']:.3f}s, " + f"Acc: {fastest_overall['message_accuracy']:.1%}, " + f"TP: {fastest_overall['tp_rate']:.1%}, " + f"FP: {fastest_overall['fp_rate']:.1%})") + # Cache performance summary final_cache_info = get_cache_info() print(f"\n🔥 CACHE PERFORMANCE:") - print(f"Models cached: {final_cache_info['cache_size']} ({final_cache_info['cached_models']})") - print(f"Cache benefit: Subsequent model loads are ~1000x faster than initial load") + print(f"Models cached: {final_cache_info['cache_size']} " + f"({final_cache_info['cached_models']})") + print(f"Cache benefit: Subsequent model loads are ~1000x faster " + f"than initial load") print(f"Memory optimization: {final_cache_info['memory_info']}") - + # Performance by ratio with message-level metrics print(f"\nMessage-level performance by ratio:") for ratio in ratios_to_test: - ratio_metrics = [m for m in performance_metrics if m['ratio'] == ratio] + ratio_metrics = [m for m in performance_metrics + if m['ratio'] == ratio] avg_time = np.mean([m['analysis_time'] for m in ratio_metrics]) - avg_accuracy = np.mean([m['message_accuracy'] for m in ratio_metrics]) + avg_accuracy = np.mean([m['message_accuracy'] + for m in ratio_metrics]) avg_tp_rate = np.mean([m['tp_rate'] for m in ratio_metrics]) avg_fp_rate = np.mean([m['fp_rate'] for m in ratio_metrics]) - print(f" {ratio}:1 ratio: {avg_time:.3f}s avg | Acc: {avg_accuracy:.1%} | TP: {avg_tp_rate:.1%} | FP: {avg_fp_rate:.1%}") + print(f" {ratio}:1 ratio: {avg_time:.3f}s avg | " + f"Acc: {avg_accuracy:.1%} | TP: {avg_tp_rate:.1%} | " + f"FP: {avg_fp_rate:.1%}") + def main(): """Run the comprehensive testing.""" import sys - + # Check for flags review_mode = '--review' in sys.argv or '-r' in sys.argv results_only = '--results-only' in sys.argv or '--results' in sys.argv - + # Ensure mutually exclusive modes if review_mode and results_only: - print("❌ Error: Cannot use both --review and --results-only flags simultaneously") + print("❌ Error: Cannot use both --review and --results-only flags " + "simultaneously") sys.exit(1) - + if review_mode: print("🔍 Review mode enabled - showing detailed analysis") elif results_only: - print("📊 Results-only mode enabled - showing rare class affinity scores per user") - + print("📊 Results-only mode enabled - showing rare class affinity " + "scores per user") + # Set random seed for reproducible results np.random.seed(42) - - test_thresholds_and_ratios(review_mode=review_mode, results_only=results_only) - + + test_thresholds_and_ratios(review_mode=review_mode, + results_only=results_only) + # Final cache cleanup and summary final_cache_info = get_cache_info() if final_cache_info['cache_size'] > 0: - print(f"\n🧹 Cache cleanup: {final_cache_info['cache_size']} models in cache") + print(f"\n🧹 Cache cleanup: {final_cache_info['cache_size']} " + f"models in cache") clear_model_cache() print(f"✅ Cache cleared successfully") - - print(f"\n✅ Testing complete! Check results above to determine optimal threshold and ratio.") + + print(f"\n✅ Testing complete! Check results above to determine " + f"optimal threshold and ratio.") if not review_mode and not results_only: - print("💡 Tip: Run with --review (-r) for detailed analysis or --results-only for user scores only") - print("🔥 Performance: Model caching enabled - repeated runs will be faster!") + print("💡 Tip: Run with --review (-r) for detailed analysis or " + "--results-only for user scores only") + print("🔥 Performance: Model caching enabled - repeated runs " + "will be faster!") + if __name__ == "__main__": main() diff --git a/src/sentinel/embeddings/sbert.py b/src/sentinel/embeddings/sbert.py index 6667df1..df23ee5 100644 --- a/src/sentinel/embeddings/sbert.py +++ b/src/sentinel/embeddings/sbert.py @@ -51,7 +51,7 @@ def get_sentence_transformer_and_scaling_fn( - A scaling function for similarity scores if needed (only for E5 family models), or None """ global _model_cache - + # Check cache first if caching is enabled if use_cache and sentence_model_name_or_path in _model_cache: LOG.debug(f"Loading cached SentenceTransformer model: {sentence_model_name_or_path}") @@ -59,7 +59,7 @@ def get_sentence_transformer_and_scaling_fn( else: LOG.debug(f"Creating new SentenceTransformer model: {sentence_model_name_or_path}") model = SentenceTransformer(sentence_model_name_or_path) - + # Cache the model if caching is enabled if use_cache: _model_cache[sentence_model_name_or_path] = model @@ -95,7 +95,7 @@ def get_cache_info() -> dict: return { "cached_models": list(_model_cache.keys()), "cache_size": len(_model_cache), - "memory_info": "Use clear_model_cache() to free memory if needed" + "memory_info": "Use clear_model_cache() to free memory if needed", } From 1505654e1ca342fb1f643369f5170249f9c972a3 Mon Sep 17 00:00:00 2001 From: rafainn Date: Tue, 19 Aug 2025 10:42:02 +0100 Subject: [PATCH 11/21] STYLE: Apply PEP 8 formatting and update examples - Apply PEP 8 formatting to Example_Threshold_Script.py - Update embeddings.safetensors - Update sentinel_against_hate.ipynb - Fixed line length violations (max 79 characters) - Corrected indentation and spacing - Enhanced readability while maintaining functionality --- examples/Example_Threshold_Script.py | 2 +- examples/sentinel_against_hate.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index ccbe6cc..75e98ff 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -214,7 +214,7 @@ def test_thresholds_and_ratios(review_mode: bool = False, # Time data loading load_start = time.time() index = SentinelLocalIndex.load( - path="./examples/hate_speech_model", + path="path/to/local/index", negative_to_positive_ratio=ratio ) load_time = time.time() - load_start diff --git a/examples/sentinel_against_hate.ipynb b/examples/sentinel_against_hate.ipynb index 5985866..bdf9338 100644 --- a/examples/sentinel_against_hate.ipynb +++ b/examples/sentinel_against_hate.ipynb @@ -2127,7 +2127,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "74e25bf5", "metadata": {}, "outputs": [ From 7ee49f841a7088528a0bd82267e0642b9b8e042a Mon Sep 17 00:00:00 2001 From: rafainn Date: Tue, 19 Aug 2025 11:32:59 +0100 Subject: [PATCH 12/21] FEAT: Add caching option to SentinelLocalIndex for improved load performance --- examples/Example_Threshold_Script.py | 8 +++++--- src/sentinel/sentinel_local_index.py | 9 ++++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index 75e98ff..87fb474 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -215,7 +215,8 @@ def test_thresholds_and_ratios(review_mode: bool = False, load_start = time.time() index = SentinelLocalIndex.load( path="path/to/local/index", - negative_to_positive_ratio=ratio + negative_to_positive_ratio=ratio, + Cache_Model=True ) load_time = time.time() - load_start total_load_time += load_time @@ -262,7 +263,7 @@ def test_thresholds_and_ratios(review_mode: bool = False, for user_name, messages in users.items(): result = index.calculate_rare_class_affinity( messages, - min_score_to_consider=threshold + min_score_to_consider=threshold, ) # Calculate statistics @@ -475,7 +476,8 @@ def test_thresholds_and_ratios(review_mode: bool = False, 'threshold': threshold, 'load_time': load_time, 'analysis_time': analysis_time, - 'false_positive_rate_messages': fp_rate, # Use the calculated rate + 'false_positive_rate_messages': fp_rate, # Use the calculated + # rate 'false_positive_rate_users': false_positive_rate_users, 'false_positive_messages': total_false_positives, 'false_positive_users': false_positive_users, diff --git a/src/sentinel/sentinel_local_index.py b/src/sentinel/sentinel_local_index.py index 6dda29d..9550cac 100644 --- a/src/sentinel/sentinel_local_index.py +++ b/src/sentinel/sentinel_local_index.py @@ -176,6 +176,7 @@ def load( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, negative_to_positive_ratio: Optional[float] = 5.0, + Cache_Model: bool = True, ) -> "SentinelLocalIndex": """ Load the index from a path and returns a new SentinelLocalIndex instance. @@ -188,6 +189,7 @@ def load( If None, preserves the original ratio from the saved index. If 5.0 (default), uses a 5:1 negative to positive ratio for optimal performance. If specified, downsamples negative examples to achieve the desired ratio. + Cache_Model: Whether to use model caching for faster subsequent loads. Default True. Returns: A new SentinelLocalIndex instance with the loaded model and embeddings. @@ -205,7 +207,10 @@ def load( # Create the sentence model and get the scaling function model_name = config.encoder_model_name_or_path - sentence_model, scale_fn = get_sentence_transformer_and_scaling_fn(model_name) + sentence_model, scale_fn = get_sentence_transformer_and_scaling_fn( + model_name, + use_cache = Cache_Model + ) # Create a new instance with the loaded model and data instance = cls( @@ -304,6 +309,7 @@ def calculate_rare_class_affinity( # Use when simulating by sampling texts from the same data indexed. prevent_exact_match: bool = False, encoding_additional_kwargs: Mapping[str, Any] = {}, + Cache_Model: bool = True, show_progress_bar: bool = False, ) -> RareClassAffinityResult: """Calculate rare class affinity for the given text samples in realtime. @@ -324,6 +330,7 @@ def calculate_rare_class_affinity( min_score_to_consider: Threshold below which scores are set to 0. prevent_exact_match: Whether to skip exact matches when scoring. encoding_additional_kwargs: Additional keyword arguments for encoding. + Cache_Model: Whether to use model caching (currently not used in this method). show_progress_bar: Whether to display a progress bar during encoding. Returns: From b638046efc14bcd2fae3bdd17ae4a81e0e746d83 Mon Sep 17 00:00:00 2001 From: rafainn Date: Tue, 19 Aug 2025 18:59:48 +0100 Subject: [PATCH 13/21] REFACTOR: Refactor user profile creation to use external test data for speech examples --- examples/Example_Threshold_Script.py | 99 +++++-------------- examples/test_data.py | 140 +++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 76 deletions(-) create mode 100644 examples/test_data.py diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index 87fb474..f8f35ee 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -28,6 +28,7 @@ from sentinel.sentinel_local_index import SentinelLocalIndex from sentinel.embeddings.sbert import clear_model_cache, get_cache_info +from test_data import NORMAL_SPEECH, HATE_SPEECH, SEXUAL_CONTENT import numpy as np import time from typing import Dict, List, Tuple @@ -36,48 +37,10 @@ def create_user_profiles() -> Dict[str, List[str]]: """Create 10 different user profiles with varying speech patterns.""" - # Normal speech examples - normal_speech = [ - "Hey how are you doing today?", - "What are you up to this weekend?", - "I love this new video game", - "The weather is really nice today", - "Did you see that movie last night?", - "I'm going to the store later", - "Work was pretty busy today", - "Let's grab lunch sometime", - "That's a really cool picture", - "I just finished reading a great book", - "How was your vacation?", - "I'm excited for the weekend", - "Thanks for helping me out", - "Have a great day!", - "What's your favorite food?", - "I'm learning to cook pasta", - "The concert was amazing", - "I need to go shopping", - "My family is visiting next week", - "That joke was hilarious" - ] - - # Hate speech examples - hate_speech = [ - "I hate all those people", - "they don't belong here", - "those immigrants are ruining everything", - "whites are superior to everyone else", - "the holocaust never happened", - "Hitler had the right idea", - "Jews control all the banks", - "Black people are all criminals", - "Muslims are all terrorists", - "gay people are disgusting", - "women belong in the kitchen", - "trans people are mentally ill", - "minorities are destroying our country", - "we need to get rid of them all", - "violence is the only solution" - ] + # Import example data from test_data module + normal_speech = NORMAL_SPEECH # 50 examples + hate_speech = HATE_SPEECH # 50 examples + # Sexual content examples @@ -87,23 +50,7 @@ def create_user_profiles() -> Dict[str, List[str]]: As such information relating to sexual content has been removed from the codebase but can be added back by removing notes. - sexual_content = [ - "want to see me naked?", - "send me some hot pics", - "let's hook up tonight", - "I'm so horny right now", - "what are you wearing?", - "wanna have some fun?", - "you look so sexy", - "let's get physical", - "I want to touch you", - "meet me in private", - "show me your body", - "let's talk dirty", - "I'm thinking about you", - "we could be alone together", - "send nudes" - ] + sexual_content = SEXUAL_SPEECH """ users = {} @@ -111,13 +58,13 @@ def create_user_profiles() -> Dict[str, List[str]]: # Normal Speech Only Users (3 users) for i in range(1, 4): users[f"normal_user_{i}"] = np.random.choice( - normal_speech, size=15, replace=False + normal_speech, size=35, replace=False ).tolist() # Hate Speech Focused Users (2 users) for i in range(1, 3): hate_msgs = np.random.choice( - hate_speech, size=10, replace=False + hate_speech, size=30, replace=False ).tolist() normal_msgs = np.random.choice( normal_speech, size=5, replace=False @@ -126,36 +73,36 @@ def create_user_profiles() -> Dict[str, List[str]]: np.random.shuffle(users[f"hate_user_{i}"]) # Sexual Content Focused Users (2 users) - for i in range(1, 3): - # sexual_msgs = np.random.choice(sexual_content, size=10, - # replace=False).tolist() - # Requires an index with positive sexual content examples - normal_msgs = np.random.choice( - normal_speech, size=5, replace=False - ).tolist() - users[f"sexual_user_{i}"] = normal_msgs # + sexual_msgs - np.random.shuffle(users[f"sexual_user_{i}"]) + # for i in range(1, 3): + # sexual_msgs = np.random.choice(sexual_content, size=10, + # replace=False).tolist() + # Requires an index with positive sexual content examples + # normal_msgs = np.random.choice( + # normal_speech, size=5, replace=False + # ).tolist() + # users[f"sexual_user_{i}"] = normal_msgs # + sexual_msgs + # np.random.shuffle(users[f"sexual_user_{i}"]) # Mixed Content Users (2 users) for i in range(1, 3): hate_msgs = np.random.choice( - hate_speech, size=5, replace=False + hate_speech, size=20, replace=False ).tolist() # sexual_msgs = np.random.choice(sexual_content, size=5, # replace=False).tolist() # Requires an index with positive sexual content examples normal_msgs = np.random.choice( - normal_speech, size=5, replace=False + normal_speech, size=15, replace=False ).tolist() users[f"mixed_user_{i}"] = hate_msgs + normal_msgs # + sexual_msgs np.random.shuffle(users[f"mixed_user_{i}"]) # All Types Combined User (1 user) - hate_msgs = np.random.choice(hate_speech, size=7, replace=False).tolist() - # sexual_msgs = np.random.choice(sexual_content, size=7, + hate_msgs = np.random.choice(hate_speech, size=10, replace=False).tolist() + # sexual_msgs = np.random.choice(sexual_content, size=10, # replace=False).tolist() # Requires an index with positive sexual content examples - normal_msgs = np.random.choice(normal_speech, size=6, replace=False).tolist() + normal_msgs = np.random.choice(normal_speech, size=10, replace=False).tolist() users["all_types_user"] = hate_msgs + normal_msgs # + sexual_msgs np.random.shuffle(users["all_types_user"]) @@ -214,7 +161,7 @@ def test_thresholds_and_ratios(review_mode: bool = False, # Time data loading load_start = time.time() index = SentinelLocalIndex.load( - path="path/to/local/index", + path="./examples/hate_speech_model", negative_to_positive_ratio=ratio, Cache_Model=True ) diff --git a/examples/test_data.py b/examples/test_data.py new file mode 100644 index 0000000..30297ee --- /dev/null +++ b/examples/test_data.py @@ -0,0 +1,140 @@ +""" +Test data for Sentinel hate speech detection testing. + +This module contains example text samples used for testing different +user behavior patterns and model performance evaluation. +""" + +# Normal speech examples (50 examples) +NORMAL_SPEECH = [ + "Hey how are you doing today?", + "What are you up to this weekend?", + "I love this new video game", + "The weather is really nice today", + "Did you see that movie last night?", + "I'm going to the store later", + "Work was pretty busy today", + "Let's grab lunch sometime", + "That's a really cool picture", + "I just finished reading a great book", + "How was your vacation?", + "I'm excited for the weekend", + "Thanks for helping me out", + "Have a great day!", + "What's your favorite food?", + "I'm learning to cook pasta", + "The concert was amazing", + "I need to go shopping", + "My family is visiting next week", + "That joke was hilarious", + "Good morning everyone!", + "Hope you have a wonderful day", + "The traffic is terrible today", + "I can't wait for summer vacation", + "My dog learned a new trick", + "This coffee tastes really good", + "I'm thinking about changing my major", + "The new restaurant downtown is excellent", + "Can you help me with this project?", + "I love spending time with friends", + "The sunset looks beautiful tonight", + "I'm so tired after that workout", + "What time does the meeting start?", + "I need to call my parents later", + "This song always makes me happy", + "The library closes at 9 PM", + "I'm planning a surprise party", + "The game was really exciting", + "I should probably get some sleep", + "Thanks for the recommendation!", + "The presentation went really well", + "I have a test tomorrow morning", + "Our team won the championship", + "I'm working on a new project", + "The deadline is next Friday", + "Class was cancelled today", + "I got a promotion at work", + "The professor explained it clearly", + "We're having a team building event", + "I need to finish my homework" +] + +# Hate speech examples (50 examples) +HATE_SPEECH = [ + "I hate all those people", + "they don't belong here", + "those immigrants are ruining everything", + "whites are superior to everyone else", + "the holocaust never happened", + "Hitler had the right idea", + "Jews control all the banks", + "Black people are all criminals", + "Muslims are all terrorists", + "gay people are disgusting", + "women belong in the kitchen", + "trans people are mentally ill", + "minorities are destroying our country", + "we need to get rid of them all", + "violence is the only solution", + "send them back where they came from", + "these people are inferior", + "our race is under attack", + "foreigners are taking our jobs", + "they're destroying our culture", + "we need to protect our bloodline", + "those degenerates deserve punishment", + "purify our society", + "ethnic cleansing is necessary", + "they're subhuman scum", + "women are naturally inferior", + "homosexuality is a disease", + "immigrants bring crime and drugs", + "they're invading our homeland", + "death to all enemies", + "burn down their neighborhoods", + "exterminate the vermin", + "racial purity must be maintained", + "they're parasites on society", + "violence against them is justified", + "eliminate the weak", + "crush the opposition", + "our people are the chosen ones", + "they pollute our gene pool", + "terrorism against them is acceptable", + "wipe them off the face of earth", + "they're a cancer on humanity", + "breed them out of existence", + "no mercy for the inferior", + "cleanse our nation", + "they deserve to suffer", + "hatred is the only answer", + "destroy their way of life", + "they're animals not humans", + "final solution needed" +] + +""" +Requires an index with sexual positive examples, current database only +focuses on hate-speech. +As such information relating to sexual content has been removed from the +codebase but can be added back by removing notes. +""" + +SEXUAL_CONTENT = [ + "want to see me naked?", + "send me some hot pics", + "let's hook up tonight", + "I'm so horny right now", + "what are you wearing?", + "wanna have some fun?", + "you look so sexy", + "let's get physical", + "I want to touch you", + "meet me in private", + "show me your body", + "let's talk dirty", + "I'm thinking about you", + "we could be alone together", + "send nudes", +] + From e3277f18e7e72d95544cc02f4597ccec711706c2 Mon Sep 17 00:00:00 2001 From: rafainn Date: Tue, 19 Aug 2025 19:00:28 +0100 Subject: [PATCH 14/21] TEST: Update mean_of_positives tests to return 0.0 for negative and empty score arrays, fixes edge case, NaN returns. --- src/sentinel/score_formulae.py | 7 ++++++- tests/test_score_formulae.py | 23 +++++++++-------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/sentinel/score_formulae.py b/src/sentinel/score_formulae.py index 90e709f..9b584e3 100644 --- a/src/sentinel/score_formulae.py +++ b/src/sentinel/score_formulae.py @@ -41,7 +41,12 @@ def mean_of_positives(scores: np.array) -> float: Returns: Mean of positive scores, indicating overall affinity to rare class content """ - return np.mean(scores[scores > 0]) + if scores.size == 0: + return 0.0 + positives = scores[scores > 0] + if positives.size == 0: + return 0.0 + return float(np.mean(positives)) def skewness(scores: np.array, min_size_of_scores: int = 10) -> float: diff --git a/tests/test_score_formulae.py b/tests/test_score_formulae.py index 214124a..1372e23 100644 --- a/tests/test_score_formulae.py +++ b/tests/test_score_formulae.py @@ -74,20 +74,15 @@ def test_mean_of_positives(): result = mean_of_positives(scores) assert result == 0.6, "Should ignore negative scores and return mean of positives" - # Test with all negative scores - will raise a RuntimeWarning, but we're checking that - # we handle the "mean of empty slice" case gracefully - with pytest.warns(RuntimeWarning): - scores = np.array([-0.5, -0.3, -0.7]) - result = mean_of_positives(scores) - # When there's no positive scores, numpy returns NaN for an empty slice - assert np.isnan( - result - ), "Should return NaN when there are no positive scores" # Test with empty array - will raise a RuntimeWarning - with pytest.warns(RuntimeWarning): - scores = np.array([]) - result = mean_of_positives(scores) - # When there's an empty array, numpy returns NaN - assert np.isnan(result), "Should return NaN for empty array" + # Test with all negative scores - should return 0.0 like other functions + scores = np.array([-0.5, -0.3, -0.7]) + result = mean_of_positives(scores) + assert result == 0.0, "Should return 0.0 when there are no positive scores" + + # Test with empty array - should return 0.0 like other functions + scores = np.array([]) + result = mean_of_positives(scores) + assert result == 0.0, "Should return 0.0 for empty array" def test_skewness(): From 1b6a2ee0982aa3ee44a843b975944d1b54ae8d24 Mon Sep 17 00:00:00 2001 From: rafainn Date: Tue, 19 Aug 2025 19:05:23 +0100 Subject: [PATCH 15/21] CHORE: Remove redundant Cache_Model variable in `calculate_rare_class_affinity`, update example file to use path/to/index rather than local path --- examples/Example_Threshold_Script.py | 2 +- src/sentinel/sentinel_local_index.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index f8f35ee..8ac60bb 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -161,7 +161,7 @@ def test_thresholds_and_ratios(review_mode: bool = False, # Time data loading load_start = time.time() index = SentinelLocalIndex.load( - path="./examples/hate_speech_model", + path="path/to/index", negative_to_positive_ratio=ratio, Cache_Model=True ) diff --git a/src/sentinel/sentinel_local_index.py b/src/sentinel/sentinel_local_index.py index 9550cac..3140bfc 100644 --- a/src/sentinel/sentinel_local_index.py +++ b/src/sentinel/sentinel_local_index.py @@ -309,7 +309,6 @@ def calculate_rare_class_affinity( # Use when simulating by sampling texts from the same data indexed. prevent_exact_match: bool = False, encoding_additional_kwargs: Mapping[str, Any] = {}, - Cache_Model: bool = True, show_progress_bar: bool = False, ) -> RareClassAffinityResult: """Calculate rare class affinity for the given text samples in realtime. @@ -330,7 +329,6 @@ def calculate_rare_class_affinity( min_score_to_consider: Threshold below which scores are set to 0. prevent_exact_match: Whether to skip exact matches when scoring. encoding_additional_kwargs: Additional keyword arguments for encoding. - Cache_Model: Whether to use model caching (currently not used in this method). show_progress_bar: Whether to display a progress bar during encoding. Returns: From d63b8ce6706dbf5481e94f1154b7abbe9b84cc22 Mon Sep 17 00:00:00 2001 From: rafainn Date: Tue, 19 Aug 2025 21:59:02 +0100 Subject: [PATCH 16/21] TEST: Add edge case tests for aggregation functions and contrastive_components in score_formulae and SentinelLocalIndex --- tests/test_score_formulae.py | 101 ++++++++++++ tests/test_sriracha_local_index.py | 248 +++++++++++++++++++++++++++++ 2 files changed, 349 insertions(+) diff --git a/tests/test_score_formulae.py b/tests/test_score_formulae.py index 1372e23..7e1de62 100644 --- a/tests/test_score_formulae.py +++ b/tests/test_score_formulae.py @@ -25,6 +25,7 @@ percentile_score, softmax_weighted_mean, max_score, + contrastive_components, ) @@ -145,3 +146,103 @@ def test_additional_aggregators(): # max_score val = max_score(scores) assert np.isclose(val, 1.0) + + +def test_aggregation_functions_edge_cases(): + """Test edge cases for aggregation functions to improve coverage.""" + # Test top_k_mean with empty array (lines 98, 101) + empty_scores = np.array([]) + assert top_k_mean(empty_scores) == 0.0 + + # Test top_k_mean with no positive scores + negative_scores = np.array([-1, -2, -3]) + assert top_k_mean(negative_scores) == 0.0 + + # Test percentile_score with empty array (lines 119, 122) + assert percentile_score(empty_scores) == 0.0 + + # Test percentile_score with no positive scores + assert percentile_score(negative_scores) == 0.0 + + # Test softmax_weighted_mean with empty array (lines 139, 142) + assert softmax_weighted_mean(empty_scores) == 0.0 + + # Test softmax_weighted_mean with no positive scores + assert softmax_weighted_mean(negative_scores) == 0.0 + + # Test max_score with empty array (lines 155, 158) + assert max_score(empty_scores) == 0.0 + + # Test max_score with no positive scores + assert max_score(negative_scores) == 0.0 + + +def test_contrastive_components_edge_cases(): + """Test edge cases for contrastive_components function.""" + # Test with divide by zero scenario (line 176) + # This is hard to trigger since we use exp(), but we can test normal operation + pos_sims = [0.5, 0.6] + neg_sims = [0.1, 0.2] + + pos_term, neg_term, log_ratio = contrastive_components(pos_sims, neg_sims) + + assert pos_term > 0 + assert neg_term > 0 + assert log_ratio != 0 + + # Test when log_ratio would be infinity (line 188) + # This tests the inf handling in contrastive_components + very_high_pos = [10.0, 10.0] # Very high similarities + very_low_neg = [-10.0, -10.0] # Very low similarities + + pos_term, neg_term, log_ratio = contrastive_components(very_high_pos, very_low_neg) + + assert pos_term > neg_term + assert log_ratio > 0 + + +class TestScoreFormulaeEdgeCases: + """Edge case tests for score formulae functions to improve coverage.""" + + def test_aggregation_functions_empty_arrays(self): + """Test aggregation functions with empty arrays.""" + empty_array = np.array([]) + + # Test mean_of_positives with empty array (line 98) + result = mean_of_positives(empty_array) + assert result == 0.0 + + # Test top_k_mean with empty array (line 119) + result = top_k_mean(empty_array, k=3) + assert result == 0.0 + + # Test top_k_mean with k larger than array size (line 122) + small_array = np.array([0.5]) + result = top_k_mean(small_array, k=3) + assert np.isclose(result, 0.5) + + # Test percentile_score with empty array (line 139) + result = percentile_score(empty_array, q=50) + assert result == 0.0 + + # Test percentile_score with all negative values (line 142) + negative_array = np.array([-1.0, -0.5, -2.0]) + result = percentile_score(negative_array, q=75) + assert result == 0.0 + + # Test skewness with empty array (line 155) + result = skewness(empty_array) + assert np.isclose(result, 0.0) + + # Test skewness with single value (line 158) + single_value = np.array([0.5]) + result = skewness(single_value) + assert np.isclose(result, 0.0) + + # Test softmax_weighted_mean with empty array + result = softmax_weighted_mean(empty_array) + assert result == 0.0 + + # Test max_score with empty array + result = max_score(empty_array) + assert result == 0.0 diff --git a/tests/test_sriracha_local_index.py b/tests/test_sriracha_local_index.py index 7e95f7d..57b9009 100644 --- a/tests/test_sriracha_local_index.py +++ b/tests/test_sriracha_local_index.py @@ -255,3 +255,251 @@ def test_end_to_end_workflow(): # Negative examples should be zero assert negative_score == 0, "Negative example should score zero" + + +class TestSentinelLocalIndexEdgeCases: + """Test edge cases and error handling in SentinelLocalIndex.""" + + def test_apply_negative_ratio_with_none(self, simple_index): + """Test _apply_negative_ratio with None value (preserve original ratio).""" + original_negative_size = simple_index.negative_embeddings.shape[0] + original_positive_size = simple_index.positive_embeddings.shape[0] + + # Test with None - should preserve original ratio and log info + simple_index._apply_negative_ratio(None) + + # Should remain unchanged + assert simple_index.negative_embeddings.shape[0] == original_negative_size + assert simple_index.positive_embeddings.shape[0] == original_positive_size + + def test_apply_negative_ratio_with_null_embeddings(self): + """Test _apply_negative_ratio with null embeddings.""" + # Create index with null embeddings + index = SentinelLocalIndex( + sentence_model=None, + positive_embeddings=None, + negative_embeddings=None, + scale_fn=None, + positive_corpus=None, + negative_corpus=None + ) + + # Should handle null embeddings gracefully + index._apply_negative_ratio(1.0) + + # Should remain None + assert index.positive_embeddings is None + assert index.negative_embeddings is None + + def test_apply_negative_ratio_with_empty_embeddings(self): + """Test _apply_negative_ratio with empty embeddings.""" + # Create empty tensors + empty_positive = torch.tensor([]).reshape(0, 384) # 0 samples, 384 dimensions + empty_negative = torch.tensor([]).reshape(0, 384) + + index = SentinelLocalIndex( + sentence_model=None, + positive_embeddings=empty_positive, + negative_embeddings=empty_negative, + scale_fn=None, + positive_corpus=[], + negative_corpus=[] + ) + + # Should handle empty embeddings gracefully + index._apply_negative_ratio(1.0) + + # Should remain empty + assert index.positive_embeddings.shape[0] == 0 + assert index.negative_embeddings.shape[0] == 0 + + def test_apply_negative_ratio_with_invalid_ratio(self, simple_index): + """Test _apply_negative_ratio with invalid ratio values.""" + original_negative_size = simple_index.negative_embeddings.shape[0] + + # Test with negative ratio + simple_index._apply_negative_ratio(-1.0) + assert simple_index.negative_embeddings.shape[0] == original_negative_size + + # Test with zero ratio + simple_index._apply_negative_ratio(0.0) + assert simple_index.negative_embeddings.shape[0] == original_negative_size + + def test_apply_negative_ratio_calculation_error(self, simple_index): + """Test _apply_negative_ratio with calculation that would cause overflow.""" + # Test with extremely large ratio that could cause overflow + import sys + simple_index._apply_negative_ratio(float(sys.maxsize)) + + # Should handle gracefully and preserve original embeddings + assert simple_index.negative_embeddings is not None + + def test_calculate_rare_class_affinity_with_prevent_exact_match(self, simple_index): + """Test calculate_rare_class_affinity with prevent_exact_match=True.""" + # Use text that might create exact matches + observations = ["unsafe behavior detected", "harmful content identified"] # These are in positive corpus + + result = simple_index.calculate_rare_class_affinity( + observations, + prevent_exact_match=True + ) + + assert isinstance(result, RareClassAffinityResult) + assert result.rare_class_affinity_score >= 0 + + def test_calculate_rare_class_affinity_with_high_threshold(self, simple_index): + """Test calculate_rare_class_affinity with very high threshold to trigger empty scores.""" + observations = ["some neutral text that won't match well"] + + result = simple_index.calculate_rare_class_affinity( + observations, + min_score_to_consider=100.0 # Extremely high threshold to ensure no scores pass + ) + + assert isinstance(result, RareClassAffinityResult) + # Should be 0.0 due to high threshold filtering out all scores + assert result.rare_class_affinity_score == 0.0 + + def test_apply_negative_ratio_zero_calculated_samples(self, simple_index): + """Test _apply_negative_ratio when calculated samples to keep is zero.""" + # Use a very small ratio that would result in 0 samples + simple_index._apply_negative_ratio(0.001) # Should result in 0 samples for typical test data + + # Should preserve original embeddings due to invalid calculated value + assert simple_index.negative_embeddings.shape[0] >= 0 + + def test_apply_negative_ratio_calculation_overflow(self, simple_index): + """Test _apply_negative_ratio with values that cause calculation errors.""" + # Test with float('inf') to trigger calculation errors + simple_index._apply_negative_ratio(float('inf')) + + # Should preserve original embeddings + assert simple_index.negative_embeddings is not None + + def test_torch_operation_error_handling(self, simple_index): + """Test error handling in torch operations during downsampling.""" + # This is harder to trigger directly, but we can test with edge case ratios + original_size = simple_index.negative_embeddings.shape[0] + + # Test with various edge case ratios + simple_index._apply_negative_ratio(0.1) # Very small ratio + + # Should complete without errors + assert simple_index.negative_embeddings.shape[0] <= original_size + + def test_debug_logging_and_neighbor_recording(self, simple_index, caplog): + """Test debug logging paths and neighbor recording.""" + import logging + caplog.set_level(logging.DEBUG) + + # Test with text that will generate debug output + observations = ["test observation for debug output"] + result = simple_index.calculate_rare_class_affinity(observations) + + assert isinstance(result, RareClassAffinityResult) + # Verify that debug logging occurred (neighbor records are always created) + + def test_torch_downsampling_runtime_error(self): + """Test RuntimeError handling during torch tensor downsampling.""" + import unittest.mock + + # Create a simple index with mock sentence transformer + from unittest.mock import MagicMock + mock_model = MagicMock() + mock_model.encode.return_value = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) + + # Create embeddings that will cause issues during downsampling + positive_embeddings = torch.tensor([[1.0, 0.0, 0.0]]) + negative_embeddings = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]) + + index = SentinelLocalIndex( + sentence_model=mock_model, + positive_embeddings=positive_embeddings, + negative_embeddings=negative_embeddings + ) + + # Mock torch.randperm to raise RuntimeError to trigger the exception handling (lines 290-292) + with unittest.mock.patch('torch.randperm', side_effect=RuntimeError("Mocked torch error")): + # This should trigger the exception handling in _apply_negative_ratio + original_size = index.negative_embeddings.shape[0] + index._apply_negative_ratio(0.5) # Try to reduce size + + # Should preserve original embeddings due to the error + assert index.negative_embeddings.shape[0] == original_size + + def test_exact_match_compensation_line_410(self, simple_index): + """Test the exact match compensation code path (line 410).""" + # Create observations that are very similar to corpus content to trigger exact matches + observations = ["unsafe behavior detected"] # This should be very close to positive corpus + + result = simple_index.calculate_rare_class_affinity( + observations, + prevent_exact_match=True, + top_k=1 # Small top_k to increase chance of exact matches + ) + + assert isinstance(result, RareClassAffinityResult) + # The exact match prevention should work without errors + + def test_assertion_error_unexpected_sign_line_424(self): + """Test the assertion error for unexpected signs (line 424).""" + # This is tricky to test directly since it's an internal consistency check + # We'll test that normal operation doesn't trigger this assertion + from unittest.mock import MagicMock + + mock_model = MagicMock() + mock_model.encode.return_value = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) + + index = SentinelLocalIndex( + sentence_model=mock_model, + positive_embeddings=torch.tensor([[1.0, 0.0]]), + negative_embeddings=torch.tensor([[0.0, 1.0]]) + ) + + # Normal operation should not trigger the assertion + result = index.calculate_rare_class_affinity(["test text"]) + assert isinstance(result, RareClassAffinityResult) + + def test_fallback_similarity_handling_lines_457_459(self): + """Test fallback similarity handling when top_k matches are insufficient.""" + from unittest.mock import MagicMock, patch + + # Create a mock that returns limited similarity results + mock_model = MagicMock() + mock_model.encode.return_value = torch.tensor([[1.0, 0.0]]) + + # Create index with minimal embeddings + index = SentinelLocalIndex( + sentence_model=mock_model, + positive_embeddings=torch.tensor([[1.0, 0.0]]), + negative_embeddings=torch.tensor([[0.0, 1.0]]), + positive_corpus=["positive text"], + negative_corpus=["negative text"] + ) + + # Mock semantic_search to return very limited results that would require fallback + with patch('sentinel.sentinel_local_index.semantic_search') as mock_search: + # Return results with very low scores or limited matches - correct format for semantic_search + mock_search.side_effect = [ + [[{"corpus_id": 0, "score": 0.1}]], # positive matches for query 0 - low score + [[{"corpus_id": 0, "score": 0.1}]] # negative matches for query 0 - low score + ] + + result = index.calculate_rare_class_affinity( + ["test text"], + top_k=5, # Request more than available + min_score_to_consider=0.0 # Allow low scores + ) + + assert isinstance(result, RareClassAffinityResult) + + def test_empty_observation_scores_line_503(self, simple_index): + """Test the empty observation scores path (line 503).""" + # Use an extremely high threshold to ensure no scores pass + result = simple_index.calculate_rare_class_affinity( + ["any text"], + min_score_to_consider=1000.0 # Impossibly high threshold + ) + + assert isinstance(result, RareClassAffinityResult) + assert result.rare_class_affinity_score == 0.0 # Should be 0.0 when no scores pass From ff7061c91c07736b51e8ad00e6edbab6d5a56624 Mon Sep 17 00:00:00 2001 From: rafainn Date: Wed, 20 Aug 2025 19:30:13 +0100 Subject: [PATCH 17/21] Chore: removed redundant imports --- src/sentinel/embeddings/sbert.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/sentinel/embeddings/sbert.py b/src/sentinel/embeddings/sbert.py index df23ee5..b0d10dc 100644 --- a/src/sentinel/embeddings/sbert.py +++ b/src/sentinel/embeddings/sbert.py @@ -19,9 +19,7 @@ """ from typing import Callable, Optional, Tuple -import os import logging -from functools import lru_cache from sentence_transformers import SentenceTransformer From b551f1c5d2c099e1299ea9d74770cd6ca65ab75f Mon Sep 17 00:00:00 2001 From: rafainn Date: Wed, 20 Aug 2025 19:40:33 +0100 Subject: [PATCH 18/21] chore: Changed caching to be false by default, to preserve old functionality, removed redundant exports, added no-cache flag to the testing script --- examples/Example_Threshold_Script.py | 21 ++++++++++++++------- src/sentinel/sentinel_local_index.py | 4 ++-- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/examples/Example_Threshold_Script.py b/examples/Example_Threshold_Script.py index 8ac60bb..3f9d498 100644 --- a/examples/Example_Threshold_Script.py +++ b/examples/Example_Threshold_Script.py @@ -31,7 +31,7 @@ from test_data import NORMAL_SPEECH, HATE_SPEECH, SEXUAL_CONTENT import numpy as np import time -from typing import Dict, List, Tuple +from typing import Dict, List def create_user_profiles() -> Dict[str, List[str]]: @@ -110,7 +110,9 @@ def create_user_profiles() -> Dict[str, List[str]]: def test_thresholds_and_ratios(review_mode: bool = False, - results_only: bool = False): + results_only: bool = False, + no_cache: bool = False, + ): """Test different threshold and ratio combinations. Args: @@ -161,9 +163,9 @@ def test_thresholds_and_ratios(review_mode: bool = False, # Time data loading load_start = time.time() index = SentinelLocalIndex.load( - path="path/to/index", + path="path/to/local/index", negative_to_positive_ratio=ratio, - Cache_Model=True + Cache_Model=True & ~no_cache ) load_time = time.time() - load_start total_load_time += load_time @@ -719,6 +721,7 @@ def main(): # Check for flags review_mode = '--review' in sys.argv or '-r' in sys.argv results_only = '--results-only' in sys.argv or '--results' in sys.argv + no_cache = '--no-cache' in sys.argv # Ensure mutually exclusive modes if review_mode and results_only: @@ -731,12 +734,16 @@ def main(): elif results_only: print("📊 Results-only mode enabled - showing rare class affinity " "scores per user") + elif no_cache: + print("Indexes will not be cached") # Set random seed for reproducible results np.random.seed(42) test_thresholds_and_ratios(review_mode=review_mode, - results_only=results_only) + results_only=results_only, + no_cache=no_cache + ) # Final cache cleanup and summary final_cache_info = get_cache_info() @@ -751,8 +758,8 @@ def main(): if not review_mode and not results_only: print("💡 Tip: Run with --review (-r) for detailed analysis or " "--results-only for user scores only") - print("🔥 Performance: Model caching enabled - repeated runs " - "will be faster!") + if no_cache: + print("Index has not been cached between models.") if __name__ == "__main__": diff --git a/src/sentinel/sentinel_local_index.py b/src/sentinel/sentinel_local_index.py index 3140bfc..cd2b176 100644 --- a/src/sentinel/sentinel_local_index.py +++ b/src/sentinel/sentinel_local_index.py @@ -176,7 +176,7 @@ def load( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, negative_to_positive_ratio: Optional[float] = 5.0, - Cache_Model: bool = True, + Cache_Model: bool = False, ) -> "SentinelLocalIndex": """ Load the index from a path and returns a new SentinelLocalIndex instance. @@ -523,4 +523,4 @@ def calculate_rare_class_affinity( aggregation_name=agg_name, aggregation_stats=agg_stats, explanations=explanations if explain else None, - ) + ) \ No newline at end of file From 0b5a48771f57b04961642ddeab3be3c836666678 Mon Sep 17 00:00:00 2001 From: rafainn Date: Thu, 19 Feb 2026 19:27:27 +0000 Subject: [PATCH 19/21] All requested changes made --- src/sentinel/sentinel_local_index.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/sentinel/sentinel_local_index.py b/src/sentinel/sentinel_local_index.py index cd2b176..782d336 100644 --- a/src/sentinel/sentinel_local_index.py +++ b/src/sentinel/sentinel_local_index.py @@ -176,7 +176,7 @@ def load( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, negative_to_positive_ratio: Optional[float] = 5.0, - Cache_Model: bool = False, + cache_Model: bool = False, ) -> "SentinelLocalIndex": """ Load the index from a path and returns a new SentinelLocalIndex instance. @@ -189,7 +189,7 @@ def load( If None, preserves the original ratio from the saved index. If 5.0 (default), uses a 5:1 negative to positive ratio for optimal performance. If specified, downsamples negative examples to achieve the desired ratio. - Cache_Model: Whether to use model caching for faster subsequent loads. Default True. + cache_Model: Whether to use model caching for faster subsequent loads. Default True. Returns: A new SentinelLocalIndex instance with the loaded model and embeddings. @@ -209,7 +209,7 @@ def load( sentence_model, scale_fn = get_sentence_transformer_and_scaling_fn( model_name, - use_cache = Cache_Model + use_cache = cache_Model ) # Create a new instance with the loaded model and data @@ -307,9 +307,12 @@ def calculate_rare_class_affinity( # Margin to ignore when text is only slightly more similar to positive than negative. min_score_to_consider: float = 0.1, # Use when simulating by sampling texts from the same data indexed. - prevent_exact_match: bool = False, - encoding_additional_kwargs: Mapping[str, Any] = {}, - show_progress_bar: bool = False, + prevent_exact_match: bool = False, + encoding_additional_kwargs: Mapping[str, Any] = {}, + show_progress_bar: bool = False, + explain: bool = True, + include_neighbors: bool = True, + neighbors_limit: int = 5, ) -> RareClassAffinityResult: """Calculate rare class affinity for the given text samples in realtime. @@ -330,6 +333,9 @@ def calculate_rare_class_affinity( prevent_exact_match: Whether to skip exact matches when scoring. encoding_additional_kwargs: Additional keyword arguments for encoding. show_progress_bar: Whether to display a progress bar during encoding. + explain: Whether to include per-text explainability details. + include_neighbors: Whether to include top-neighbor records in explainability output. + neighbors_limit: Maximum number of neighbor records to include per text. Returns: RareClassAffinityResult containing both the overall affinity score and @@ -366,11 +372,6 @@ def calculate_rare_class_affinity( top_k=top_k + additional_neighbors, ) - # Explainability defaults (always on for transparency) - explain = True - include_neighbors = True - neighbors_limit = 5 - observation_scores = {} explanations = {} if explain else None From 32d644aa4895ab34114fb3564e85cae81135c5b5 Mon Sep 17 00:00:00 2001 From: rafainn Date: Thu, 19 Feb 2026 19:42:31 +0000 Subject: [PATCH 20/21] fix capitalisation issue --- src/sentinel/sentinel_local_index.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sentinel/sentinel_local_index.py b/src/sentinel/sentinel_local_index.py index 782d336..3a67000 100644 --- a/src/sentinel/sentinel_local_index.py +++ b/src/sentinel/sentinel_local_index.py @@ -176,7 +176,7 @@ def load( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, negative_to_positive_ratio: Optional[float] = 5.0, - cache_Model: bool = False, + cache_model: bool = False, ) -> "SentinelLocalIndex": """ Load the index from a path and returns a new SentinelLocalIndex instance. @@ -189,7 +189,7 @@ def load( If None, preserves the original ratio from the saved index. If 5.0 (default), uses a 5:1 negative to positive ratio for optimal performance. If specified, downsamples negative examples to achieve the desired ratio. - cache_Model: Whether to use model caching for faster subsequent loads. Default True. + cache_model: Whether to use model caching for faster subsequent loads. Default True. Returns: A new SentinelLocalIndex instance with the loaded model and embeddings. @@ -209,7 +209,7 @@ def load( sentence_model, scale_fn = get_sentence_transformer_and_scaling_fn( model_name, - use_cache = cache_Model + use_cache = cache_model ) # Create a new instance with the loaded model and data From e8535cd960c7ec72013fc77fd9a753d0df423fd2 Mon Sep 17 00:00:00 2001 From: rafainn Date: Thu, 7 May 2026 23:20:00 +0100 Subject: [PATCH 21/21] The should fix thee build issues for CI tests as pandas ^1.0.0 didn't have support for PEP 517 builds hence swapped to ^2.0.0 which is compatable - may require further testing however didn't impact functionality of code --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 56ab3a5..d87a519 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,13 +48,14 @@ pytest-timeout = "^2.4.0" [tool.poetry.group.docs.dependencies] furo = "^2024.8.6" sphinx-copybutton = "^0.5.2" +sphinx = ">=7.0.0,<8.0.0" [tool.poetry.group.examples.dependencies] datasets = "^2.14.0" jupyter = "^1.0.0" ipykernel = "^6.0.0" matplotlib = "^3.0.0" -pandas = "^1.0.0" +pandas = "^2.0.0" beautifulsoup4 = "^4.10.0" markdownify = "^0.11.0" python-slugify = "^8.0.0"