Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions scripts/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import typing
import unittest

import numpy as np
from jax import numpy as jnp

# module hack
Expand Down Expand Up @@ -138,15 +137,15 @@ def teset_no_val(self) -> None:
class TestPred(unittest.TestCase):

def test_standard_setup(self) -> None:
X = np.array([
X = jnp.array([
[1, 1, 0],
[1, 0, 1],
[0, 1, 0],
[0, 0, 1],
])
phis = np.array([0.4, 0.2, -0.3])
phis = jnp.array([0.4, 0.2, -0.3])
N = X.shape[0]
rows, cols = np.where(X == 1)
rows, cols = jnp.where(X == 1)
res = train.pred(phis, rows, cols, N)
expected = [
0.4 + 0.2 - (-0.3) > 0,
Expand All @@ -160,8 +159,8 @@ def test_standard_setup(self) -> None:
class TestGetMetrics(unittest.TestCase):

def test_standard_setup(self) -> None:
pred = np.array([0, 0, 1, 0, 0], dtype=bool)
target = np.array([1, 0, 1, 1, 1], dtype=bool)
pred = jnp.array([0, 0, 1, 0, 0], dtype=bool)
target = jnp.array([1, 0, 1, 1, 1], dtype=bool)
result = train.get_metrics(pred, target)
self.assertEqual(result.tp, 1)
self.assertEqual(result.tn, 1)
Expand All @@ -176,7 +175,7 @@ def test_standard_setup(self) -> None:


class TestUpdate(unittest.TestCase):
X = np.array([
X = jnp.array([
[1, 0, 1, 0],
[0, 1, 0, 0],
[0, 0, 0, 0],
Expand All @@ -185,10 +184,10 @@ class TestUpdate(unittest.TestCase):
])

def test_standard_setup1(self) -> None:
rows, cols = np.where(self.X == 1)
rows, cols = jnp.where(self.X == 1)
M = self.X.shape[-1]
Y = np.array([1, 1, 0, 0, 1], dtype=bool)
w = np.array([0.1, 0.3, 0.1, 0.1, 0.4])
Y = jnp.array([1, 1, 0, 0, 1], dtype=bool)
w = jnp.array([0.1, 0.3, 0.1, 0.1, 0.4])
scores = jnp.zeros(M)
new_w, new_scores, best_feature_index, added_score = train.update(
w, scores, rows, cols, Y)
Expand Down Expand Up @@ -250,8 +249,8 @@ def test_fit(self) -> None:
model.setdefault(weight[0], 0)
model[weight[0]] += float(weight[1])
self.assertEqual(scores.shape[0], len(features))
loaded_scores = [model.get(feature, 0) for feature in features]
self.assertTrue(np.all(np.isclose(scores, loaded_scores)))
loaded_scores = jnp.array([model.get(feature, 0) for feature in features])
self.assertTrue(jnp.all(jnp.isclose(scores, loaded_scores)))
os.remove(weights_file_path)
os.remove(log_file_path)

Expand Down
47 changes: 21 additions & 26 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@

import jax
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt

EPS = np.finfo(float).eps # type: np.floating[typing.Any]
EPS: float = jnp.finfo(float).eps
DEFAULT_OUTPUT_NAME = 'weights.txt'
DEFAULT_LOG_NAME = 'train.log'
DEFAULT_FEATURE_THRES = 10
Expand Down Expand Up @@ -140,34 +138,33 @@ def preprocess(


@partial(jax.jit, static_argnums=[3])
def pred(scores: npt.NDArray[np.float32], rows: npt.NDArray[np.int32],
cols: npt.NDArray[np.int32], N: int) -> npt.NDArray[np.bool_]:
def pred(scores: jax.Array, rows: jax.Array, cols: jax.Array,
N: int) -> jax.Array:
"""Predicts the target output from the learned scores and input entries.

Args:
scores (numpy.ndarray): Contribution scores of features.
rows (numpy.ndarray): Row indices of True values in the input.
cols (numpy.ndarray): Column indices of True values in the input.
scores (jax.Array): Contribution scores of features.
rows (jax.Array): Row indices of True values in the input.
cols (jax.Array): Column indices of True values in the input.
N (int): The number of input entries.

Returns:
res (numpy.ndarray): A prediction of the target.
res (jax.Array): A prediction of the target.
"""
# This is equivalent to (2X - 1).dot(scores) = 2 * X.dot(scores) - scores.sum(),
# where X is the N x M binary input matrix, but in a sparse matrix-friendly way.
r: npt.NDArray[np.float32] = 2 * jax.ops.segment_sum(
scores.take(cols), rows, N) - scores.sum()
r: jax.Array = 2 * jax.ops.segment_sum(scores.take(cols), rows,
N) - scores.sum()
return r > 0


@jax.jit
def get_metrics(pred: npt.NDArray[np.bool_],
actual: npt.NDArray[np.bool_]) -> Result:
def get_metrics(pred: jax.Array, actual: jax.Array) -> Result:
"""Gets evaluation metrics from the prediction and the actual target.

Args:
pred (numpy.ndarray): A prediction of the target.
actual (numpy.ndarray): The actual target.
pred (jax.Array): A prediction of the target.
actual (jax.Array): The actual target.

Returns:
result (Result): A result.
Expand All @@ -192,23 +189,21 @@ def get_metrics(pred: npt.NDArray[np.bool_],


@jax.jit
def update(
w: npt.NDArray[np.float32], scores: typing.Any, rows: npt.NDArray[np.int32],
cols: npt.NDArray[np.int32], Y: npt.NDArray[np.bool_]
) -> typing.Tuple[typing.Any, typing.Any, int, float]:
def update(w: jax.Array, scores: jax.Array, rows: jax.Array, cols: jax.Array,
Y: jax.Array) -> typing.Tuple[jax.Array, jax.Array, int, float]:
"""Calculates the new weight vector and the contribution scores.

Args:
w (numpy.ndarray): A weight vector.
w (jax.Array): A weight vector.
scores (jax.Array): Contribution scores of features.
rows (numpy.ndarray): Row indices of True values in the input data.
cols (numpy.ndarray): Column indices of True values in the input data.
Y (numpy.ndarray): The target output.
rows (jax.Array): Row indices of True values in the input data.
cols (jax.Array): Column indices of True values in the input data.
Y (jax.Array): The target output.


Returns:
A tuple of the following items:
- w (numpy.ndarray): The new weight vector.
- w (jax.Array): The new weight vector.
- scores (jax.Array): The new contribution scores.
- best_feature_index (int): The index of the best feature.
- score (float): The newly added score for the best feature.
Expand Down Expand Up @@ -238,7 +233,7 @@ def update(

def fit(dataset_train: Dataset, dataset_val: typing.Optional[Dataset],
features: typing.List[str], iters: int, weights_filename: str,
log_filename: str, out_span: int) -> typing.Any:
log_filename: str, out_span: int) -> jax.Array:
"""Trains an AdaBoost binary classifier.

Args:
Expand All @@ -251,7 +246,7 @@ def fit(dataset_train: Dataset, dataset_val: typing.Optional[Dataset],
out_span (int): Iteration span to output metrics and weights.

Returns:
scores (Any): The contribution scores.
scores (jax.Array): The contribution scores.
"""
with open(weights_filename, 'w') as f:
f.write('')
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ dev =
build
flake8
isort
numpy
mypy
pytest
toml
Expand Down