@@ -58,12 +58,12 @@ def _calculate_label_metrics(self, metric: str, category: str):
             [sample["output"] for sample in self.data[category]["data"]]
 
         flag = False
-        softmaxs = []
+        logits = []
         for i, sample in enumerate(self.data[category]["data"]):
-            if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))):
+            if np.any(np.isnan(np.array(list(sample["logits_over_choices"].values())))):
                 if not flag:
                     print(
-                        f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}."
+                        f"NaN in the logits, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}."
                     )
                     flag = True
                 score = 0
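
This hunk renames the per-sample field `softmax_over_choices` to `logits_over_choices` (and the `softmaxs` accumulator to `logits`): when any per-choice value is NaN, the evaluator falls back to exact-match scoring instead of taking an argmax over unusable numbers. A minimal sketch of that guard, with a made-up sample record:

import numpy as np

# Hypothetical per-sample record mirroring the field used in the diff;
# the choice keys and values are invented for illustration.
sample = {"logits_over_choices": {"A": 1.3, "B": float("nan"), "C": 0.2}}

choice_scores = np.array(list(sample["logits_over_choices"].values()))
if np.any(np.isnan(choice_scores)):
    prediction = None  # fall back to exact-match scoring, as the evaluator does
else:
    prediction = int(np.argmax(choice_scores))  # index of the best-scoring choice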
@@ -79,13 +79,13 @@ def _calculate_label_metrics(self, metric: str, category: str):
                         score,
                         metric_helper.accuracy_by_options(sample["input"], sample["output"], ref),
                     )
-                softmaxs.append(references[i] if score == 1 else -1)
+                logits.append(references[i] if score == 1 else -1)
             else:
-                softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values()))))
+                logits.append(np.argmax(np.array(list(sample["logits_over_choices"].values()))))
 
         references = np.array(references)
-        softmaxs = np.array(softmaxs)
-        scores = np.sum(references == softmaxs) / len(self.data[category]["data"]) * 100
+        logits = np.array(logits)
+        scores = np.sum(references == logits) / len(self.data[category]["data"]) * 100
 
         self.evaluation_results[metric][category] = (scores, len(self.data[category]["data"]))
         self.evaluation_results[metric]["ALL"] += scores * weight
@@ -105,12 +105,12 @@ def _calculate_combined_metrics(self, metric: str, category: str):
         predictions = [sample["output"] for sample in self.data[category]["data"]]
 
         flag = False
-        softmaxs = []
+        logits = []
         for i, sample in enumerate(self.data[category]["data"]):
-            if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))):
+            if np.any(np.isnan(np.array(list(sample["logits_over_choices"].values())))):
                 if not flag:
                     print(
-                        f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}."
+                        f"NaN in the logits, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}."
                     )
                     flag = True
                 score = 0
@@ -121,16 +121,14 @@ def _calculate_combined_metrics(self, metric: str, category: str):
                             sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]
                         ),
                     )
-                softmaxs.append(references[i] if score == 1 else -1)
+                logits.append(references[i] if score == 1 else -1)
             else:
-                softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values()))))
+                logits.append(np.argmax(np.array(list(sample["logits_over_choices"].values()))))
 
         metric_method = eval("metric_helper." + metric)
 
         total_score = 0.0
-        for prediction, reference, references_label, softmax in zip(
-            predictions, references, references_labels, softmaxs
-        ):
+        for prediction, reference, references_label, softmax in zip(predictions, references, references_labels, logits):
             score = 0.0
 
             for ref in reference:
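
This hunk applies the same rename in `_calculate_combined_metrics` and collapses the multi-line `zip(...)` call onto one line (note the loop variable keeps its old `softmax` name). For context, the `metric_method = eval("metric_helper." + metric)` line just above the loop resolves a metric function by name; a sketch of that dispatch against a stand-in helper module, where `getattr` is the equivalent lookup:

import types

# Stand-in for the project's metric_helper module; this exact_match
# metric is invented for illustration.
metric_helper = types.SimpleNamespace(exact_match=lambda pred, ref: 100.0 * (pred == ref))

metric = "exact_match"
metric_method = getattr(metric_helper, metric)  # same effect as eval("metric_helper." + metric)
print(metric_method("B", "B"))  # 100.0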
@@ -116,10 +116,10 @@ def _load_model(
             shard_config: Shard config for tensor parallel.
 
         """
-        model_kwargs.setdefault("torch_dtype", torch.float16)
-
         if "torch_dtype" in model_kwargs:
             model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
+        else:
+            model_kwargs.setdefault("torch_dtype", torch.float16)
 
         if "config" in model_kwargs:
             model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
@@ -586,11 +586,10 @@ def _load_model(
             shard_config: Shard config for tensor parallel.
 
         """
-
-        model_kwargs.setdefault("torch_dtype", torch.float16)
-
         if "torch_dtype" in model_kwargs:
             model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
+        else:
+            model_kwargs.setdefault("torch_dtype", torch.float16)
 
         if "config" in model_kwargs:
             model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
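
The same fix is applied to a second `_load_model` implementation further down the file. For context, both branches assume `model_kwargs` arrives with string values, as it presumably would from a config file, given that `torch_dtype` goes through `eval()` and `config` through `AutoConfig.from_pretrained()`. A hypothetical mapping (the model id is only an example):

model_kwargs = {
    "torch_dtype": "torch.float16",        # string form, resolved via eval()
    "config": "meta-llama/Llama-2-7b-hf",  # resolved via AutoConfig.from_pretrained()
}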