diff --git a/tests/test_sdk.py b/tests/test_sdk.py index 4a7f952b..10cddb0a 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -143,6 +143,18 @@ def test_mock_register(self, mock_post): resp = self.base.register(email=os.getenv("WATTTIME_EMAIL")) self.assertEqual(len(mock_post.call_args_list), 1) + def test_get_password(self): + + with mock.patch.dict(os.environ, {}, clear=True), self.assertRaises(ValueError): + wt_base = WattTimeBase() + + with mock.patch.dict(os.environ, {}, clear=True): + wt_base = WattTimeBase( + username="WATTTIME_USERNAME", password="WATTTIME_PASSWORD" + ) + self.assertEqual(wt_base.username, "WATTTIME_USERNAME") + self.assertEqual(wt_base.password, "WATTTIME_PASSWORD") + class TestWattTimeHistorical(unittest.TestCase): def setUp(self): @@ -464,6 +476,11 @@ def test_historical_forecast_jsons_multithreaded(self): class TestWattTimeMaps(unittest.TestCase): def setUp(self): self.maps = WattTimeMaps() + self.myaccess = WattTimeMyAccess() + + def tearDown(self): + self.maps.session.close() + self.myaccess.session.close() def tearDown(self): self.maps.session.close() @@ -503,6 +520,22 @@ def test_region_from_loc(self): self.assertEqual(region["region_full_name"], "Public Service Co of Colorado") self.assertEqual(region["signal_type"], "co2_moer") + def test_my_access_in_geojson(self): + access = self.myaccess.get_access_pandas() + for signal_type in ["co2_moer", "co2_aoer", "health_damage"]: + access_regions = access.loc[ + access["signal_type"] == signal_type, "region" + ].unique() + maps = self.maps.get_maps_json(signal_type=signal_type) + maps_regions = [i["properties"]["region"] for i in maps["features"]] + + assert ( + set(access_regions) - set(maps_regions) == set() + ), f"Missing regions in geojson for {signal_type}: {set(access_regions) - set(maps_regions)}" + assert ( + set(maps_regions) - set(access_regions) == set() + ), f"Extra regions in geojson for {signal_type}: {set(maps_regions) - set(access_regions)}" + class TestWattTimeMarginalFuelMix(unittest.TestCase): def setUp(self): diff --git a/watttime/api.py b/watttime/api.py index b84b575e..8be1113e 100644 --- a/watttime/api.py +++ b/watttime/api.py @@ -37,8 +37,17 @@ def __init__( worker_count (int): The number of worker threads to use for multithreading. Default is min(10, (os.cpu_count() or 1) * 2). """ - self.username = username or os.getenv("WATTTIME_USER") - self.password = password or os.getenv("WATTTIME_PASSWORD") + + # This only applies to the current session, is not stored persistently + if username and not os.getenv("WATTTIME_USER"): + os.environ["WATTTIME_USER"] = username + if password and not os.getenv("WATTTIME_PASSWORD"): + os.environ["WATTTIME_PASSWORD"] = password + + # Accessing attributes will raise exception if variables are not set + _ = self.password + _ = self.username + self.token = None self.headers = None self.token_valid_until = None @@ -56,6 +65,28 @@ def __init__( self.session = requests.Session() + @property + def password(self): + password = os.getenv("WATTTIME_PASSWORD") + if not password: + raise ValueError( + "WATTTIME_PASSWORD env variable is not set." + + "Please set this variable, or pass in a password upon initialization," + + "which will store it as a variable only for the current session" + ) + return password + + @property + def username(self): + username = os.getenv("WATTTIME_USER") + if not username: + raise ValueError( + "WATTTIME_USER env variable is not set." + + "Please set this variable, or pass in a username upon initialization," + + "which will store it as a variable only for the current session" + ) + return username + def _login(self): """ Login to the WattTime API, which provides a JWT valid for 30 minutes