Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions scripts/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,32 +84,32 @@ def test_standard_setup(self) -> None:
'1\tbar\tfoo\n'
'-1\tbaz\tqux\n'))
# The input matrix X and the target vector Y should look like below now:
# Y X(foo bar baz BIAS)
# 1 1 1 0 1
# -1 1 0 0 1
# 1 1 1 1 1
# 1 1 1 0 1
# -1 0 0 1 1
# Y X(foo bar baz)
# 1 1 1 0
# -1 1 0 0
# 1 1 1 1
# 1 1 1 0
# -1 0 0 1
rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 1)
self.assertEqual(features, ['foo', 'bar', 'baz'])
self.assertEqual(Y.tolist(), [True, False, True, True, False])
self.assertEqual(rows.tolist(), [0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4])
self.assertEqual(cols.tolist(), [0, 1, 3, 0, 3, 0, 1, 2, 3, 1, 0, 3, 2, 3])
self.assertEqual(rows.tolist(), [0, 0, 1, 2, 2, 2, 3, 3, 4])
self.assertEqual(cols.tolist(), [0, 1, 0, 0, 1, 2, 1, 0, 2])

def test_skip_invalid_rows(self) -> None:
with open(self.ENTRIES_FILE_PATH, 'w') as f:
f.write(('\n1\tfoo\tbar\n'
'-1\n\n'
'-1\tfoo\n\n'))
# The input matrix X and the target vector Y should look like below now:
# Y X(foo bar BIAS)
# 1 1 1 1
# -1 1 0 1
# Y X(foo bar)
# 1 1 1
# -1 1 0
rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 0)
self.assertEqual(features, ['foo', 'bar'])
self.assertEqual(Y.tolist(), [True, False])
self.assertEqual(rows.tolist(), [0, 0, 0, 1, 1])
self.assertEqual(cols.tolist(), [0, 1, 2, 0, 2])
self.assertEqual(rows.tolist(), [0, 0, 1])
self.assertEqual(cols.tolist(), [0, 1, 0])

def tearDown(self) -> None:
if (os.path.exists(self.ENTRIES_FILE_PATH)):
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_standard_setup(self) -> None:
self.assertEqual(result.fscore, 2 * p * r / (p + r))


class TestUpdateWeights(unittest.TestCase):
class TestUpdate(unittest.TestCase):
X = np.array([
[1, 0, 1, 0],
[0, 1, 0, 0],
Expand All @@ -194,8 +194,8 @@ def test_standard_setup1(self) -> None:
Y = np.array([1, 1, 0, 0, 1], dtype=bool)
w = np.array([0.1, 0.3, 0.1, 0.1, 0.4])
scores = jnp.zeros(M)
new_w, new_scores, best_feature_index, added_score = train.update_weights(
w, rows, cols, Y, scores, M)
new_w, new_scores, best_feature_index, added_score = train.update(
w, scores, rows, cols, Y)
self.assertFalse(w.argmax() == 0)
self.assertTrue(new_w.argmax() == 0)
self.assertFalse(scores.argmax() == 1)
Expand Down Expand Up @@ -254,9 +254,8 @@ def test_fit(self) -> None:
for weight in weights:
model.setdefault(weight[0], 0)
model[weight[0]] += float(weight[1])
self.assertEqual(scores.shape[0], len(features) + 1)
loaded_scores = [model.get(feature, 0) for feature in features
] + [model.get('BIAS', 0)]
self.assertEqual(scores.shape[0], len(features))
loaded_scores = [model.get(feature, 0) for feature in features]
self.assertTrue(np.all(np.isclose(scores, loaded_scores)))

def tearDown(self) -> None:
Expand Down
58 changes: 27 additions & 31 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def preprocess(
- features (List[str]): The list of features.
"""
features_counter: typing.Counter[str] = Counter()
N = 0
X = []
Y = array.array('B')
with open(entries_filename) as f:
Expand All @@ -77,7 +76,6 @@ def preprocess(
Y.append(cols[0] == '1')
X.append(cols[1:])
features_counter.update(cols[1:])
N += 1
features = [
item[0]
for item in features_counter.most_common()
Expand All @@ -90,19 +88,17 @@ def preprocess(
hit_indices = [feature_index[feat] for feat in x if feat in feature_index]
rows.extend(i for _ in range(len(hit_indices)))
cols.extend(hit_indices) # type: ignore
rows.append(i)
cols.append(len(features)) # type: ignore
return jnp.asarray(rows), jnp.asarray(cols), jnp.asarray(
Y, dtype=bool), features


def split_data(
rows: npt.NDArray[np.int64],
cols: npt.NDArray[np.int64],
rows: npt.NDArray[np.int32],
cols: npt.NDArray[np.int32],
Y: npt.NDArray[np.bool_],
split_ratio: float = .9
) -> typing.Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64],
npt.NDArray[np.int64], npt.NDArray[np.int64],
) -> typing.Tuple[npt.NDArray[np.int32], npt.NDArray[np.int32],
npt.NDArray[np.int32], npt.NDArray[np.int32],
npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
"""Splits a dataset into a training dataset and a test dataset.

Expand All @@ -129,23 +125,23 @@ def split_data(


@partial(jax.jit, static_argnums=[3])
def pred(phis: npt.NDArray[np.float64], rows: npt.NDArray[np.int64],
cols: npt.NDArray[np.int64], N: int) -> npt.NDArray[np.bool_]:
def pred(scores: npt.NDArray[np.float32], rows: npt.NDArray[np.int32],
cols: npt.NDArray[np.int32], N: int) -> npt.NDArray[np.bool_]:
"""Predicts the target output from the learned scores and input entries.

Args:
phis (numpy.ndarray): Contribution scores of features.
scores (numpy.ndarray): Contribution scores of features.
rows (numpy.ndarray): Row indices of True values in the input.
cols (numpy.ndarray): Column indices of True values in the input.
N (int): The number of input entries.

Returns:
res (numpy.ndarray): A prediction of the target.
"""
# This is equivalent to phis.dot(2X - 1) = 2phis.dot(X) - phis.sum() but in a
# sparse matrix-friendly way.
r: npt.NDArray[np.float64] = 2 * jax.ops.segment_sum(
phis.take(cols), rows, N) - phis.sum()
# This is equivalent to scores.dot(2X - 1) = 2 * scores.dot(X) - scores.sum()
# but in a sparse matrix-friendly way.
r: npt.NDArray[np.float32] = 2 * jax.ops.segment_sum(
scores.take(cols), rows, N) - scores.sum()
return r > 0


Expand Down Expand Up @@ -180,20 +176,20 @@ def get_metrics(pred: npt.NDArray[np.bool_],
)


@partial(jax.jit, static_argnums=[5])
def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64],
cols: npt.NDArray[np.int64], Y: npt.NDArray[np.bool_],
scores: typing.Any,
M: int) -> typing.Tuple[typing.Any, typing.Any, int, float]:
"""Calculates the new weight vector from the best feature and its score.
@jax.jit
def update(
w: npt.NDArray[np.float32], scores: typing.Any, rows: npt.NDArray[np.int32],
cols: npt.NDArray[np.int32], Y: npt.NDArray[np.bool_]
) -> typing.Tuple[typing.Any, typing.Any, int, float]:
"""Calculates the new weight vector and the contribution scores.

Args:
w (numpy.ndarray): A weight vector.
scores (JAX array): Contribution scores of features.
rows (numpy.ndarray): Row indices of True values in the input data.
cols (numpy.ndarray): Column indices of True values in the input data.
Y (numpy.ndarray): The target output.
scores (JAX array): Contribution scores of features.
M (int): The number of columns in the input data.


Returns:
A tuple of the following items:
Expand All @@ -202,6 +198,8 @@ def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64],
- best_feature_index (int): The index of the best feature.
- score (float): The newly added score for the best feature.
"""
N = w.shape[0]
M = scores.shape[0]
# This is equivalent to w.dot(Y[:, None] ^ X). Note that y ^ x = y + x - 2yx,
# hence w.dot(y ^ x) = w.dot(y) - (w * (2y - 1)).dot(x).
# `segment_sum` is used to implement sparse matrix-friendly dot products.
Expand All @@ -211,7 +209,6 @@ def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64],
positivity: bool = res.at[best_feature_index].get() < 0.5
err_min = err.at[best_feature_index].get()
amount: float = jnp.log((1 - err_min) / (err_min + EPS))
N = Y.shape[0]

# This is equivalent to X_best = X[:, best_feature_index]
X_best = jnp.zeros(
Expand All @@ -224,8 +221,8 @@ def update_weights(w: npt.NDArray[np.float64], rows: npt.NDArray[np.int64],
return w, scores, best_feature_index, score


def fit(rows_train: npt.NDArray[np.int64], cols_train: npt.NDArray[np.int64],
rows_test: npt.NDArray[np.int64], cols_test: npt.NDArray[np.int64],
def fit(rows_train: npt.NDArray[np.int32], cols_train: npt.NDArray[np.int32],
rows_test: npt.NDArray[np.int32], cols_test: npt.NDArray[np.int32],
Y_train: npt.NDArray[np.bool_], Y_test: npt.NDArray[np.bool_],
features: typing.List[str], iters: int, weights_filename: str,
log_filename: str, out_span: int) -> typing.Any:
Expand Down Expand Up @@ -255,7 +252,7 @@ def fit(rows_train: npt.NDArray[np.int64], cols_train: npt.NDArray[np.int64],
'test_accuracy\ttest_precision\ttest_recall\ttest_fscore\n')
print('Outputting learned weights to %s ...' % (weights_filename))

M = len(features) + 1
M = len(features)
scores = jnp.zeros(M)
feature_score_buffer: typing.List[typing.Tuple[str, float]] = []
N_train = Y_train.shape[0]
Expand Down Expand Up @@ -296,11 +293,10 @@ def output_progress(t: int) -> None:
))

for t in range(iters):
w, scores, best_feature_index, score = update_weights(
w, rows_train, cols_train, Y_train, scores, M)
w, scores, best_feature_index, score = update(w, scores, rows_train,
cols_train, Y_train)
w.block_until_ready()
feature = features[best_feature_index] if (
best_feature_index < len(features)) else 'BIAS'
feature = features[best_feature_index]
feature_score_buffer.append((feature, score))
if (t + 1) % out_span == 0:
output_progress(t + 1)
Expand Down