Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions scripts/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import math
import os
import sys
import tempfile
import typing
import unittest

Expand Down Expand Up @@ -73,11 +74,10 @@ def test_cmdargs_full(self) -> None:


class TestPreprocess(unittest.TestCase):
ENTRIES_FILE_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'entries_test.txt'))

def test_standard_setup(self) -> None:
with open(self.ENTRIES_FILE_PATH, 'w') as f:
entries_file_path = tempfile.NamedTemporaryFile().name
with open(entries_file_path, 'w') as f:
f.write(('1\tfoo\tbar\n'
'-1\tfoo\n'
'1\tfoo\tbar\tbaz\n'
Expand All @@ -90,30 +90,29 @@ def test_standard_setup(self) -> None:
# 1 1 1 1
# 1 1 1 0
# -1 0 0 1
rows, cols, Y, features = train.preprocess(self.ENTRIES_FILE_PATH, 1)
rows, cols, Y, features = train.preprocess(entries_file_path, 1)
self.assertEqual(features, ['foo', 'bar', 'baz'])
self.assertEqual(Y.tolist(), [True, False, True, True, False])
self.assertEqual(rows.tolist(), [0, 0, 1, 2, 2, 2, 3, 3, 4])
self.assertEqual(cols.tolist(), [0, 1, 0, 0, 1, 2, 1, 0, 2])
os.remove(entries_file_path)

def test_skip_invalid_rows(self) -> None:
    """Check the output of train.preprocess for a file with unusable rows.

    The entries file written here contains empty lines and a row that has
    a label but no features; the assertions expect only the two rows that
    carry features to show up in the result.
    """
    # Work inside a private temporary directory instead of reusing the name
    # of an already-discarded NamedTemporaryFile: the path is guaranteed to
    # be unique, and the directory (with the file in it) is removed even
    # when one of the assertions below fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        entries_file_path = os.path.join(tmp_dir, 'entries_test.txt')
        with open(entries_file_path, 'w') as f:
            f.write(('\n1\tfoo\tbar\n'
                     '-1\n\n'
                     '-1\tfoo\n\n'))
        # The input matrix X and the target vector Y should look like below now:
        # Y X(foo bar)
        # 1 1 1
        # -1 1 0
        rows, cols, Y, features = train.preprocess(entries_file_path, 0)
        self.assertEqual(features, ['foo', 'bar'])
        self.assertEqual(Y.tolist(), [True, False])
        self.assertEqual(rows.tolist(), [0, 0, 1])
        self.assertEqual(cols.tolist(), [0, 1, 0])


class TestSplitData(unittest.TestCase):
Expand Down Expand Up @@ -205,12 +204,10 @@ def test_standard_setup1(self) -> None:


class TestFit(unittest.TestCase):
WEIGHTS_FILE_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'weights_test.txt'))
LOG_FILE_PATH = os.path.abspath(
os.path.join(os.path.dirname(__file__), 'train_test.log'))

def test_fit(self) -> None:
weights_file_path = tempfile.NamedTemporaryFile().name
log_file_path = tempfile.NamedTemporaryFile().name
# Prepare a dataset that the 2nd feature (= the 2nd col in X) perfectly
# correlates with Y in a negative way.
X = np.array([
Expand All @@ -225,8 +222,8 @@ def test_fit(self) -> None:
iters = 5
out_span = 2
scores = train.fit(rows, cols, rows, cols, Y, Y, features, iters,
self.WEIGHTS_FILE_PATH, self.LOG_FILE_PATH, out_span)
with open(self.WEIGHTS_FILE_PATH) as f:
weights_file_path, log_file_path, out_span)
with open(weights_file_path) as f:
weights = [
line.split('\t') for line in f.read().splitlines() if line.strip()
]
Expand All @@ -238,7 +235,7 @@ def test_fit(self) -> None:
iters,
msg='The number of lines should equal to the iteration count.')

with open(self.LOG_FILE_PATH) as f:
with open(log_file_path) as f:
log = [line.split('\t') for line in f.read().splitlines() if line.strip()]
self.assertEqual(
len(log),
Expand All @@ -257,10 +254,8 @@ def test_fit(self) -> None:
self.assertEqual(scores.shape[0], len(features))
loaded_scores = [model.get(feature, 0) for feature in features]
self.assertTrue(np.all(np.isclose(scores, loaded_scores)))

def tearDown(self) -> None:
os.remove(self.WEIGHTS_FILE_PATH)
os.remove(self.LOG_FILE_PATH)
os.remove(weights_file_path)
os.remove(log_file_path)


if __name__ == '__main__':
Expand Down