Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 92 additions & 110 deletions pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py

Large diffs are not rendered by default.

66 changes: 50 additions & 16 deletions pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,11 @@ def token_gradients(
Computes gradients of the loss with respect to the coordinates.

Args:
model (Transformer Model):
The transformer model to be used.
input_ids (torch.Tensor):
The input sequence in the form of token ids.
input_slice (slice):
The slice of the input sequence for which gradients need to be computed.
target_slice (slice):
The slice of the input sequence to be used as targets.
loss_slice (slice):
The slice of the logits to be used for computing the loss.
model (Any): The transformer model to be used.
input_ids (torch.Tensor): The input sequence in the form of token ids.
input_slice (slice): The slice of the input sequence for which gradients need to be computed.
target_slice (slice): The slice of the input sequence to be used as targets.
loss_slice (slice): The slice of the logits to be used for computing the loss.

Returns:
torch.Tensor: The gradients of each token in the input_slice with respect to the loss.
Expand Down Expand Up @@ -72,18 +67,25 @@ def token_gradients(


class GCGAttackPrompt(AttackPrompt):
    """Attack prompt specialized for GCG: exposes token-gradient computation."""

    def grad(self, model: Any) -> torch.Tensor:
        """
        Compute token gradients for this prompt.

        Args:
            model (Any): The transformer model to compute gradients with.

        Returns:
            torch.Tensor: Gradients with respect to control tokens.
        """
        # Move the token ids onto the model's device before differentiating.
        ids_on_device = self.input_ids.to(model.device)
        return token_gradients(
            model,
            ids_on_device,
            self._control_slice,
            self._target_slice,
            self._loss_slice,
        )


class GCGPromptManager(PromptManager):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
"""GCG-specific prompt manager that implements control token sampling."""

def sample_control(
self,
Expand All @@ -93,6 +95,19 @@ def sample_control(
temp: int = 1,
allow_non_ascii: bool = True,
) -> torch.Tensor:
"""
Sample new control token candidates based on gradients.

Args:
grad (torch.Tensor): Gradient tensor for control tokens.
batch_size (int): Number of candidate controls to generate.
topk (int): Number of top gradient positions to sample from. Defaults to 256.
temp (int): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.
allow_non_ascii (bool): Whether to allow non-ASCII tokens. Defaults to True.

Returns:
torch.Tensor: Batch of new candidate control token sequences.
"""
if not allow_non_ascii:
grad[:, self._nonascii_toks.to(grad.device)] = np.inf
top_indices = (-grad).topk(topk, dim=1).indices
Expand All @@ -109,11 +124,11 @@ def sample_control(


class GCGMultiPromptAttack(MultiPromptAttack):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
"""GCG-specific multi-prompt attack that implements the GCG optimization step."""

def step(
self,
*,
batch_size: int = 1024,
topk: int = 256,
temp: int = 1,
Expand All @@ -123,6 +138,25 @@ def step(
verbose: bool = False,
filter_cand: bool = True,
) -> tuple[str, float]:
"""
Execute one GCG optimization step.

Aggregates gradients across workers, samples candidate controls,
evaluates them, and returns the best candidate.

Args:
batch_size (int): Number of candidate controls per batch. Defaults to 1024.
topk (int): Number of top gradient positions to sample from. Defaults to 256.
temp (int): Temperature for sampling. Currently unused but kept for API compatibility. Defaults to 1.
allow_non_ascii (bool): Whether to allow non-ASCII tokens. Defaults to True.
target_weight (float): Weight for target loss. Defaults to 1.
control_weight (float): Weight for control loss. Defaults to 0.1.
verbose (bool): Whether to show progress bars. Defaults to False.
filter_cand (bool): Whether to filter invalid candidates. Defaults to True.

Returns:
tuple[str, float]: The best control string and its normalized loss.
"""
main_device = self.models[0].device
control_cands = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ control_init: "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"
n_steps: 500
test_steps: 50
batch_size: 512
lr: 0.01
learning_rate: 0.01
topk: 256
temp: 1
filter_cand: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ control_init: "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"
n_steps: 500
test_steps: 50
batch_size: 512
lr: 0.01
learning_rate: 0.01
topk: 256
temp: 1
filter_cand: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ control_init: "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"
n_steps: 500
test_steps: 50
batch_size: 512
lr: 0.01
learning_rate: 0.01
topk: 256
temp: 1
filter_cand: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ control_init: "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"
n_steps: 500
test_steps: 50
batch_size: 512
lr: 0.01
learning_rate: 0.01
topk: 256
temp: 1
filter_cand: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ control_init: "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"
n_steps: 500
test_steps: 50
batch_size: 512
lr: 0.01
learning_rate: 0.01
topk: 256
temp: 1
filter_cand: True
Expand Down
66 changes: 60 additions & 6 deletions pyrit/auxiliary_attacks/gcg/experiments/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,59 @@
import logging
import subprocess as sp
import time
from typing import Any
from typing import Any, Optional

import mlflow

logger = logging.getLogger(__name__)

_DEFAULT_PARAM_KEYS: list[str] = [
    "model_name",
    "transfer",
    "n_train_data",
    "n_test_data",
    "n_steps",
    "batch_size",
]


def log_params(
    *,
    params: Any,
    param_keys: Optional[list[str]] = None,
) -> None:
    """
    Log selected parameters to MLflow.

    Args:
        params (Any): A config object with a `to_dict()` method containing all parameters.
        param_keys (Optional[list[str]]): Keys to extract and log. Defaults to standard GCG training keys.
    """
    if param_keys is None:
        param_keys = _DEFAULT_PARAM_KEYS
    # Call to_dict() once instead of once per key: avoids rebuilding the
    # dict N times (and repeating any side effects to_dict may have).
    all_params = params.to_dict()
    mlflow_params = {key: all_params[key] for key in param_keys}
    mlflow.log_params(mlflow_params)


def log_train_goals(*, train_goals: list[str]) -> None:
    """
    Record the training goals as a timestamped text artifact in MLflow.

    Args:
        train_goals (list[str]): The list of training goal strings to log.
    """
    stamp = time.strftime("%Y%m%d-%H%M%S")
    mlflow.log_text("\n".join(train_goals), f"train_goals_{stamp}.txt")


def get_gpu_memory() -> dict[str, int]:
"""
Query free GPU memory via nvidia-smi.

Returns:
dict[str, int]: Mapping of GPU identifiers to free memory in MiB.
"""
command = "nvidia-smi --query-gpu=memory.free --format=csv"
memory_free_info = sp.check_output(command.split()).decode("ascii").split("\n")[:-1][1:]
memory_free_values = {f"gpu{i + 1}_free_memory": int(val.split()[0]) for i, val in enumerate(memory_free_info)}
Expand All @@ -34,17 +65,40 @@ def get_gpu_memory() -> dict[str, int]:
return memory_free_values


def log_gpu_memory(*, step: int, synchronous: bool = False) -> None:
    """
    Report free GPU memory to MLflow as one metric per GPU.

    Args:
        step (int): The current training step number.
        synchronous (bool): Whether to log synchronously. Defaults to False.
    """
    for metric_name, free_mib in get_gpu_memory().items():
        mlflow.log_metric(metric_name, free_mib, step=step, synchronous=synchronous)


def log_loss(*, step: int, loss: float, synchronous: bool = False) -> None:
    """
    Record the training loss metric in MLflow under the key "loss".

    Args:
        step (int): The current training step number.
        loss (float): The loss value to log.
        synchronous (bool): Whether to log synchronously. Defaults to False.
    """
    mlflow.log_metric("loss", loss, step=step, synchronous=synchronous)


def log_table_summary(losses: list[float], controls: list[str], n_steps: int) -> None:
def log_table_summary(*, losses: list[float], controls: list[str], n_steps: int) -> None:
"""
Log a summary table of losses and controls to MLflow.

Args:
losses (list[float]): Loss values for each step.
controls (list[str]): Control strings for each step.
n_steps (int): Total number of steps.
"""
timestamp = time.strftime("%Y%m%d-%H%M%S")
mlflow.log_table(
{
Expand Down
43 changes: 29 additions & 14 deletions pyrit/auxiliary_attacks/gcg/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,47 @@
from typing import Any, Dict, Union

import yaml
from train import GreedyCoordinateGradientAdversarialSuffixGenerator

from pyrit.auxiliary_attacks.gcg.experiments.train import GreedyCoordinateGradientAdversarialSuffixGenerator
from pyrit.setup.initialization import _load_environment_files

_MODEL_NAMES: list[str] = ["mistral", "llama_2", "llama_3", "vicuna", "phi_3_mini"]
_ALL_MODELS: str = "all_models"


def _load_yaml_to_dict(config_path: str) -> dict[str, Any]:
    """
    Load a YAML config file and return its contents as a dictionary.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        dict[str, Any]: The parsed configuration dictionary.
    """
    # Explicit encoding: the platform-default encoding is not guaranteed to
    # be UTF-8 (e.g. cp1252 on Windows), which can corrupt config reads.
    with open(config_path, "r", encoding="utf-8") as f:
        data: dict[str, Any] = yaml.safe_load(f)
    return data


MODEL_NAMES = ["mistral", "llama_2", "llama_3", "vicuna", "phi_3_mini"]
ALL_MODELS = "all_models"
MODEL_PARAM_OPTIONS = MODEL_NAMES + [ALL_MODELS]


def run_trainer(*, model_name: str, setup: str = "single", **extra_config_parameters: Any) -> None:
"""
Trains and generates adversarial suffix - single model single prompt.

Args:
model_name (str): The name of the model, currently supports:
"mistral", "llama_2", "llama_3", "vicuna", "phi_3_mini", "all_models"
setup (str): Identifier for the setup, currently supporst
setup (str): Identifier for the setup, currently supports
- "single": one prompt one model
- "multiple": multiple prompts one model or multiple prompts multiple models
**extra_config_parameters: Additional parameters to override config values.

Raises:
ValueError: If model_name is not supported or HUGGINGFACE_TOKEN is not set.
"""
if model_name not in MODEL_NAMES:
raise ValueError(
"Model name not supported. Currently supports 'mistral', 'llama_2', 'llama_3', 'vicuna', and 'phi_3_mini'"
)
if model_name not in _MODEL_NAMES and model_name != _ALL_MODELS:
supported_models: str = "', '".join(_MODEL_NAMES + [_ALL_MODELS])
raise ValueError(f"Model name not supported. Currently supports '{supported_models}'")

_load_environment_files(env_files=None)
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
Expand Down Expand Up @@ -70,7 +79,13 @@ def run_trainer(*, model_name: str, setup: str = "single", **extra_config_parame
trainer.generate_suffix(**config)


def parse_arguments() -> argparse.Namespace:
def _parse_arguments() -> argparse.Namespace:
"""
Parse command-line arguments for the adversarial suffix trainer.

Returns:
argparse.Namespace: Parsed arguments.
"""
parser = argparse.ArgumentParser(description="Script to run the adversarial suffix trainer")
parser.add_argument("--model_name", type=str, help="The name of the model")
parser.add_argument(
Expand All @@ -90,10 +105,10 @@ def parse_arguments() -> argparse.Namespace:


if __name__ == "__main__":
args = parse_arguments()
args = _parse_arguments()
run_trainer(
model_name=args.model_name,
num_train_models=len(MODEL_NAMES) if args.model_name == ALL_MODELS else 1,
num_train_models=len(_MODEL_NAMES) if args.model_name == _ALL_MODELS else 1,
setup=args.setup,
n_train_data=args.n_train_data,
n_test_data=args.n_test_data,
Expand Down
Loading