Skip to content
56 changes: 55 additions & 1 deletion tests/test_sdk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import unittest.mock as mock
from unittest.mock import patch
from datetime import datetime, timedelta, date
from dateutil.parser import parse
from pytz import timezone, UTC
Expand Down Expand Up @@ -47,13 +48,27 @@ def raise_for_status(self):

class TestWattTimeBase(unittest.TestCase):
def setUp(self):
self.base = WattTimeBase()
"""Create both single-threaded and multi-threaded instances."""
self.base = WattTimeBase(multithreaded=False, rate_limit=2)
self.base_mt = WattTimeBase(multithreaded=True, rate_limit=2)

def test_login_with_real_api(self):
self.base._login()
assert self.base.token is not None
assert self.base.token_valid_until > datetime.now()

@patch("time.sleep", return_value=None)
def test_apply_rate_limit(self, mock_sleep):
"""Test _apply_rate_limit (single-threaded) triggers sleep when rate limit is exceeded."""
# Set up a scenario with more requests than allowed:
# Our rate_limit is 2, so with 3 prior requests all within the past second, we expect a wait.
self.base._last_request_times = [0, 0.2, 0.3]
# Use a current timestamp such that all three are within the last 1 second.
ts = 0.5
self.base._apply_rate_limit(ts)
# Expect sleep to be called with: wait_time = 1.0 - (ts - earliest_timestamp) = 1.0 - (0.5 - 0) = 0.5 seconds.
mock_sleep.assert_called_with(0.5)

def test_parse_dates_with_string(self):
start = "2022-01-01"
end = "2022-01-31"
Expand Down Expand Up @@ -125,10 +140,23 @@ def test_mock_register(self, mock_post):
resp = self.base.register(email=os.getenv("WATTTIME_EMAIL"))
self.assertEqual(len(mock_post.call_args_list), 1)

def test_get_password(self):

with mock.patch.dict(os.environ, {}, clear=True), self.assertRaises(ValueError):
wt_base = WattTimeBase()

with mock.patch.dict(os.environ, {}, clear=True):
wt_base = WattTimeBase(
username="WATTTIME_USERNAME", password="WATTTIME_PASSWORD"
)
self.assertEqual(wt_base.username, "WATTTIME_USERNAME")
self.assertEqual(wt_base.password, "WATTTIME_PASSWORD")


class TestWattTimeHistorical(unittest.TestCase):
def setUp(self):
self.historical = WattTimeHistorical()
self.historical_mt = WattTimeHistorical(multithreaded=True)

def test_get_historical_jsons_3_months(self):
start = "2022-01-01 00:00Z"
Expand All @@ -139,6 +167,15 @@ def test_get_historical_jsons_3_months(self):
self.assertGreaterEqual(len(jsons), 1)
self.assertIsInstance(jsons[0], dict)

def test_get_historical_jsons_3_months_multithreaded(self):
start = "2022-01-01 00:00Z"
end = "2022-12-31 00:00Z"
jsons = self.historical_mt.get_historical_jsons(start, end, REGION)

self.assertIsInstance(jsons, list)
self.assertGreaterEqual(len(jsons), 1)
self.assertIsInstance(jsons[0], dict)

def test_get_historical_jsons_1_week(self):
start = "2022-01-01 00:00Z"
end = "2022-01-07 00:00Z"
Expand Down Expand Up @@ -405,6 +442,7 @@ def test_horizon_hours(self):
class TestWattTimeMaps(unittest.TestCase):
def setUp(self):
self.maps = WattTimeMaps()
self.myaccess = WattTimeMyAccess()

def test_get_maps_json_moer(self):
moer = self.maps.get_maps_json(signal_type="co2_moer")
Expand Down Expand Up @@ -441,6 +479,22 @@ def test_region_from_loc(self):
self.assertEqual(region["region_full_name"], "Public Service Co of Colorado")
self.assertEqual(region["signal_type"], "co2_moer")

def test_my_access_in_geojson(self):
access = self.myaccess.get_access_pandas()
for signal_type in ["co2_moer", "co2_aoer", "health_damage"]:
access_regions = access.loc[
access["signal_type"] == signal_type, "region"
].unique()
maps = self.maps.get_maps_json(signal_type=signal_type)
maps_regions = [i["properties"]["region"] for i in maps["features"]]

assert (
set(access_regions) - set(maps_regions) == set()
), f"Missing regions in geojson for {signal_type}: {set(access_regions) - set(maps_regions)}"
assert (
set(maps_regions) - set(access_regions) == set()
), f"Extra regions in geojson for {signal_type}: {set(maps_regions) - set(access_regions)}"


if __name__ == "__main__":
unittest.main()
1 change: 0 additions & 1 deletion watttime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from watttime.api import *
from watttime.tcy import TCYCalculator
from watttime.util import RateLimitedRequesterMixin
Loading
Loading