Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 39 additions & 24 deletions algorithms/linfa-logistic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,32 +301,33 @@ where
{
let y = y.as_single_targets();

// counts the instances of two distinct class labels
let mut binary_classes = [None, None];
// find binary classes of our target dataset
for class in y {
binary_classes = match binary_classes {
// count the first class label
[None, None] => [Some((class, 1)), None],
// if the class has already been counted, increment the count
[Some((c, count)), c2] if c == class => [Some((class, count + 1)), c2],
[c1, Some((c, count))] if c == class => [c1, Some((class, count + 1))],
// count the second class label
[Some(c1), None] => [Some(c1), Some((class, 1))],

// should not be possible
[None, Some(_)] => unreachable!("impossible binary class array"),
// found 3rd distinct class
[Some(_), Some(_)] => return Err(Error::TooManyClasses),
};
}

let (pos_class, neg_class) = match binary_classes {
let (class_a, class_b) = match binary_classes {
[Some(a), Some(b)] => (a, b),
_ => return Err(Error::TooFewClasses),
};

let mut target_array = y
// Sort by label value (Ord), not by encounter order or count.
// The smaller label is always negative (-1),
// the larger label is always positive (+1).
let (neg_class, pos_class) = if class_a.0 < class_b.0 {
(class_a, class_b)
} else {
(class_b, class_a)
};

let target_array = y
.into_iter()
.map(|x| {
if x == pos_class.0 {
Expand All @@ -337,24 +338,14 @@ where
})
.collect::<Array1<_>>();

let (pos_cl, neg_cl) = if pos_class.1 < neg_class.1 {
// If we found the larger class first, flip the sign in the target
// vector, so that -1.0 is always the label for the smaller class
// and 1.0 the label for the larger class
target_array *= -F::one();
(neg_class.0.clone(), pos_class.0.clone())
} else {
(pos_class.0.clone(), neg_class.0.clone())
};

Ok((
BinaryClassLabels {
pos: ClassLabel {
class: pos_cl,
class: pos_class.0.clone(),
label: F::POSITIVE_LABEL,
},
neg: ClassLabel {
class: neg_cl,
class: neg_class.0.clone(),
label: F::NEGATIVE_LABEL,
},
},
Expand Down Expand Up @@ -989,7 +980,7 @@ mod test {
let dataset = Dataset::new(x, y);
let res = log_reg.fit(&dataset).unwrap();
assert_abs_diff_eq!(res.intercept(), 0.0);
assert!(res.params().abs_diff_eq(&array![-0.681], 1e-3));
assert!(res.params().abs_diff_eq(&array![0.681], 1e-3));
assert_eq!(
&res.predict(dataset.records()),
dataset.targets().as_single_targets()
Expand Down Expand Up @@ -1172,7 +1163,7 @@ mod test {
let dataset = Dataset::new(x, y);
let res = log_reg.fit(&dataset).unwrap();
assert_abs_diff_eq!(res.intercept(), 0.0_f32);
assert!(res.params().abs_diff_eq(&array![-0.682_f32], 1e-3));
assert!(res.params().abs_diff_eq(&array![0.682_f32], 1e-3));
assert_eq!(
&res.predict(dataset.records()),
dataset.targets().as_single_targets()
Expand Down Expand Up @@ -1375,4 +1366,28 @@ mod test {
}
));
}

#[test]
fn label_order_independent() {
let x1 = array![[-1.0], [1.0], [-0.5], [0.5]];
let y1 = array!["cat", "dog", "cat", "dog"];

let x2 = array![[1.0], [-1.0], [0.5], [-0.5]];
let y2 = array!["dog", "cat", "dog", "cat"];

let model1 = LogisticRegression::default()
.fit(&Dataset::new(x1, y1))
.unwrap();
let model2 = LogisticRegression::default()
.fit(&Dataset::new(x2, y2))
.unwrap();

assert_eq!(model1.labels().pos.class, "dog");
assert_eq!(model1.labels().neg.class, "cat");
assert_eq!(model2.labels().pos.class, "dog");
assert_eq!(model2.labels().neg.class, "cat");

assert_abs_diff_eq!(model1.intercept(), model2.intercept());
assert!(model1.params().abs_diff_eq(model2.params(), 1e-6));
}
}
Loading