From 48dcd4309bd8c791b24959a275ee19de633f4677 Mon Sep 17 00:00:00 2001 From: sam-watttime Date: Tue, 11 Feb 2025 15:14:29 -0700 Subject: [PATCH 1/8] refactor multithreading into WattTimeBase --- tests/test_sdk.py | 54 +++++++++- watttime/__init__.py | 1 - watttime/api.py | 230 +++++++++++++++++++++++++++---------------- watttime/util.py | 88 ----------------- 4 files changed, 199 insertions(+), 174 deletions(-) delete mode 100644 watttime/util.py diff --git a/tests/test_sdk.py b/tests/test_sdk.py index e3df9e4c..f08e8127 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -1,5 +1,6 @@ import unittest import unittest.mock as mock +from unittest.mock import patch from datetime import datetime, timedelta, date from dateutil.parser import parse from pytz import timezone, UTC @@ -47,13 +48,52 @@ def raise_for_status(self): class TestWattTimeBase(unittest.TestCase): def setUp(self): - self.base = WattTimeBase() + """Create both single-threaded and multi-threaded instances.""" + self.base = WattTimeBase(multithreaded=False, rate_limit=2) + self.base_mt = WattTimeBase(multithreaded=True, rate_limit=2) def test_login_with_real_api(self): self.base._login() assert self.base.token is not None assert self.base.token_valid_until > datetime.now() + @patch("time.sleep", return_value=None) + def test_apply_rate_limit(self, mock_sleep): + """Test _apply_rate_limit (single-threaded) triggers sleep when rate limit is exceeded.""" + # Set up a scenario with more requests than allowed: + # Our rate_limit is 2, so with 3 prior requests all within the past second, we expect a wait. + self.base._last_request_times = [0, 0.2, 0.3] + # Use a current timestamp such that all three are within the last 1 second. + ts = 0.5 + self.base._apply_rate_limit(ts) + # Expect sleep to be called with: wait_time = 1.0 - (ts - earliest_timestamp) = 1.0 - (0.5 - 0) = 0.5 seconds. + mock_sleep.assert_called_with(0.5) + + @patch.object(WattTimeBase, "_make_rate_limited_request") + def test_fetch_data_multithreaded(self, mock_make_rate_limited_request): + """Test _fetch_data_multithreaded calls _make_rate_limited_request for each request.""" + # For multi-threaded instance, simulate _make_rate_limited_request to return the passed parameters. + mock_make_rate_limited_request.side_effect = lambda url, params: { + "data": params + } + + url = f"{self.base_mt.url_base}/test_endpoint" + param_chunks = [{"param1": i} for i in range(5)] # Simulate 5 requests + + responses = self.base_mt._fetch_data_multithreaded(url, param_chunks) + + self.assertEqual(len(responses), 5) # Ensure all requests are handled + mock_make_rate_limited_request.assert_called() # Ensure it's called at least once + + expected_calls = [({"param1": i}) for i in range(5)] + actual_calls = [ + call.args[1] for call in mock_make_rate_limited_request.call_args_list + ] + + self.assertListEqual( + expected_calls, actual_calls + ) # Ensure correct params are passed + def test_parse_dates_with_string(self): start = "2022-01-01" end = "2022-01-31" @@ -120,7 +160,7 @@ def test_parse_dates_with_datetime(self): self.assertIsInstance(parsed_end, datetime) self.assertEqual(parsed_end.tzinfo, UTC) - @mock.patch("requests.post", side_effect=mocked_register) + @mock.patch("watttime.requests.Session.post", side_effect=mocked_register) def test_mock_register(self, mock_post): resp = self.base.register(email=os.getenv("WATTTIME_EMAIL")) self.assertEqual(len(mock_post.call_args_list), 1) @@ -129,6 +169,7 @@ def test_mock_register(self, mock_post): class TestWattTimeHistorical(unittest.TestCase): def setUp(self): self.historical = WattTimeHistorical() + self.historical_mt = WattTimeHistorical(multithreaded=True) def test_get_historical_jsons_3_months(self): start = "2022-01-01 00:00Z" @@ -139,6 +180,15 @@ def test_get_historical_jsons_3_months(self): self.assertGreaterEqual(len(jsons), 1) self.assertIsInstance(jsons[0], dict) + def test_get_historical_jsons_3_months_multithreaded(self): + start = "2022-01-01 00:00Z" + end = "2022-12-31 00:00Z" + jsons = self.historical_mt.get_historical_jsons(start, end, REGION) + + self.assertIsInstance(jsons, list) + self.assertGreaterEqual(len(jsons), 1) + self.assertIsInstance(jsons[0], dict) + def test_get_historical_jsons_1_week(self): start = "2022-01-01 00:00Z" end = "2022-01-07 00:00Z" diff --git a/watttime/__init__.py b/watttime/__init__.py index 15b87ab0..5930ce89 100644 --- a/watttime/__init__.py +++ b/watttime/__init__.py @@ -1,3 +1,2 @@ from watttime.api import * from watttime.tcy import TCYCalculator -from watttime.util import RateLimitedRequesterMixin \ No newline at end of file diff --git a/watttime/api.py b/watttime/api.py index 4d451dff..43ec8007 100644 --- a/watttime/api.py +++ b/watttime/api.py @@ -1,7 +1,8 @@ import os import time import threading -from datetime import date, time, datetime, timedelta +import time +from datetime import date, datetime, timedelta, time as dt_time from functools import cache from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, Union @@ -12,13 +13,17 @@ from dateutil.parser import parse from pytz import UTC -from watttime.util import RateLimitedRequesterMixin - class WattTimeBase: url_base = "https://api.watttime.org" - def __init__(self, username: Optional[str] = None, password: Optional[str] = None): + def __init__( + self, + username: Optional[str] = None, + password: Optional[str] = None, + multithreaded: bool = False, + rate_limit: int = 10, + ): """ Initializes a new instance of the class. @@ -29,8 +34,21 @@ def __init__(self, username: Optional[str] = None, password: Optional[str] = Non self.username = username or os.getenv("WATTTIME_USER") self.password = password or os.getenv("WATTTIME_PASSWORD") self.token = None + self.headers = None self.token_valid_until = None + self.multithreaded = multithreaded + self.rate_limit = rate_limit + self._last_request_times = [] + + if self.multithreaded: + self._rate_limit_lock = ( + threading.Lock() + ) # prevent multiple threads from modifying _last_request_times simultaneously + self._rate_limit_condition = threading.Condition(self._rate_limit_lock) + + self.session = requests.Session() + def _login(self): """ Login to the WattTime API, which provides a JWT valid for 30 minutes @@ -40,7 +58,7 @@ def _login(self): """ url = f"{self.url_base}/login" - rsp = requests.get( + rsp = self.session.get( url, auth=requests.auth.HTTPBasicAuth(self.username, self.password), timeout=20, @@ -50,6 +68,7 @@ def _login(self): self.token_valid_until = datetime.now() + timedelta(minutes=30) if not self.token: raise Exception("failed to log in, double check your credentials") + self.headers = {"Authorization": "Bearer " + self.token} def _is_token_valid(self) -> bool: if not self.token_valid_until: @@ -130,7 +149,7 @@ def register(self, email: str, organization: Optional[str] = None) -> None: "org": organization, } - rsp = requests.post(url, json=params, timeout=20) + rsp = self.session.post(url, json=params, timeout=20) rsp.raise_for_status() print( f"Successfully registered {self.username}, please check {email} for a verification email" @@ -158,25 +177,118 @@ def region_from_loc( Returns: Dict[str, str]: A dictionary containing the region information with keys "region" and "region_full_name". """ - if not self._is_token_valid(): - self._login() url = f"{self.url_base}/v3/region-from-loc" - headers = {"Authorization": "Bearer " + self.token} params = { "latitude": str(latitude), "longitude": str(longitude), "signal_type": signal_type, } - rsp = requests.get(url, headers=headers, params=params) - if not rsp.ok: - if rsp.status_code == 404: - # here we specifically cannot find a location that was provided - raise Exception( - f"\nAPI Response Error: {rsp.status_code}, {rsp.text} [{rsp.headers.get('x-request-id')}]" - ) - else: - rsp.raise_for_status() - return rsp.json() + j = self._make_rate_limited_request(url, params=params) + return j + + def _make_rate_limited_request(self, url: str, params: Dict[str, Any]) -> Dict: + """ + Makes a single API request while respecting the rate limit. + """ + if not self._is_token_valid() or not self.headers: + self._login() + + ts = time.time() + + # apply rate limiting by either sleeping (single thread) or + # waiting on a condition () + if self.multithreaded: + with self._rate_limit_condition: + self._apply_rate_limit(ts) + else: + self._apply_rate_limit(ts) + + try: + rsp = self.session.get(url, headers=self.headers, params=params) + rsp.raise_for_status() + j = rsp.json() + except requests.exceptions.RequestException as e: + raise RuntimeError( + f"API Request Failed: {e}\nURL: {url}\nParams: {params}" + ) from e + + if j.get("meta", {}).get("warnings"): + print("Warnings Returned: %s | Response: %s", params, j["meta"]) + + return j + + def _apply_rate_limit(self, ts: float): + """ + Rate limiting not allowing more than self.rate_limit requests per second. + + This is applied by checking is `self._last_request_times` has more than self.rate_limit entries. + If so, it will wait until the oldest entry is older than 1 second. + + If multithreading, waiting is achieved by setting a "condition" on the thread. + If single threading, we sleep for the remaining time. + """ + self._last_request_times = [t for t in self._last_request_times if ts - t < 1.0] + + if len(self._last_request_times) >= self.rate_limit: + earliest_request_age = ts - self._last_request_times[0] + wait_time = 1.0 - earliest_request_age + if wait_time > 0: + if self.multithreaded: + self._rate_limit_condition.wait(timeout=wait_time) + else: + time.sleep(wait_time) + + self._last_request_times.append(time.time()) + + if self.multithreaded: + self._rate_limit_condition.notify_all() + + def _fetch_data( + self, + url: str, + param_chunks: Union[Dict[str, Any], List[Dict[str, Any]]], + ) -> List[Dict]: + """ + Base method for fetching data without multithreading. + If you are making a single request, you can call _make_rate_limited_request directly. + This class is suited for making a series of requests in a for loop, with + varying `param_chunks`. + """ + + if isinstance(param_chunks, dict): + param_chunks = [param_chunks] + + responses = [] + for params in param_chunks: + rsp = self._make_rate_limited_request(url, params) + responses.append(rsp) + + return responses + + def _fetch_data_multithreaded( + self, url: str, param_chunks: List[Dict[str, Any]] + ) -> List[Dict]: + """ + Fetch data using multithreading with rate limiting. + + Args: + url (str): API endpoint URL. + param_chunks (List[Dict[str, Any]]): List of parameter sets. + + Returns: + List[Dict]: A list of JSON responses. + """ + responses = [] + with ThreadPoolExecutor(max_workers=os.cpu_count() * 5) as executor: + futures = { + executor.submit(self._make_rate_limited_request, url, params): params + for params in param_chunks + } + + for future in as_completed(futures): + responses.append(future.result()) + + return responses class WattTimeHistorical(WattTimeBase): @@ -207,11 +319,7 @@ def get_historical_jsons( Returns: List[dict]: A list of dictionary representations of the .json response object """ - if not self._is_token_valid(): - self._login() url = "{}/v3/historical".format(self.url_base) - headers = {"Authorization": "Bearer " + self.token} - responses = [] params = {"region": region, "signal_type": signal_type} start, end = self._parse_dates(start, end) @@ -221,20 +329,11 @@ def get_historical_jsons( if model is not None: params["model"] = model - for c in chunks: - params["start"], params["end"] = c - rsp = requests.get(url, headers=headers, params=params) - try: - rsp.raise_for_status() - j = rsp.json() - responses.append(j) - except Exception as e: - raise Exception( - f"\nAPI Response Error: {rsp.status_code}, {rsp.text} [{rsp.headers.get('x-request-id')}]" - ) - - if len(j["meta"]["warnings"]): - print("\n", "Warnings Returned:", params, j["meta"]) + param_chunks = [{**params, "start": c[0], "end": c[1]} for c in chunks] + if self.multithreaded: + responses = self._fetch_data_multithreaded(url, param_chunks) + else: + responses = self._fetch_data(url, param_chunks) # the API should not let this happen, but ensure for sanity unique_models = set([r["meta"]["model"]["date"] for r in responses]) @@ -325,13 +424,8 @@ def get_access_json(self) -> Dict: Raises: Exception: If the token is not valid. """ - if not self._is_token_valid(): - self._login() url = "{}/v3/my-access".format(self.url_base) - headers = {"Authorization": "Bearer " + self.token} - rsp = requests.get(url, headers=headers) - rsp.raise_for_status() - return rsp.json() + return self._make_rate_limited_request(url, params={}) def get_access_pandas(self) -> pd.DataFrame: """ @@ -372,17 +466,7 @@ def get_access_pandas(self) -> pd.DataFrame: return out -class WattTimeForecast(WattTimeBase, RateLimitedRequesterMixin): - def __init__( - self, - username: Optional[str] = None, - password: Optional[str] = None, - multithreaded: bool = False, - ): - super().__init__(username=username, password=password) - RateLimitedRequesterMixin.__init__(self) - self.multithreaded = multithreaded - +class WattTimeForecast(WattTimeBase): def _parse_historical_forecast_json( self, json_list: List[Dict[str, Any]] ) -> pd.DataFrame: @@ -429,8 +513,6 @@ def get_forecast_json( Returns: List[dict]: A list of dictionaries representing the forecast data in JSON format. """ - if not self._is_token_valid(): - self._login() params = { "region": region, "signal_type": signal_type, @@ -442,10 +524,7 @@ def get_forecast_json( params["model"] = model url = "{}/v3/forecast".format(self.url_base) - headers = {"Authorization": "Bearer " + self.token} - rsp = requests.get(url, headers=headers, params=params) - rsp.raise_for_status() - return rsp.json() + return self._make_rate_limited_request(url, params) def get_forecast_pandas( self, @@ -485,11 +564,7 @@ def get_historical_forecast_json( model: Optional[Union[str, date]] = None, horizon_hours: int = 24, ) -> List[Dict[str, Any]]: - if not self._is_token_valid(): - self._login() - url = f"{self.url_base}/v3/forecast/historical" - headers = {"Authorization": f"Bearer {self.token}"} params = { "region": region, "signal_type": signal_type, @@ -505,11 +580,9 @@ def get_historical_forecast_json( param_chunks = [{**params, "start": c[0], "end": c[1]} for c in chunks] if self.multithreaded: - return self._fetch_data_multithreaded(url, headers, param_chunks) + return self._fetch_data_multithreaded(url, param_chunks) else: - return [ - self._make_rate_limited_request(url, headers, p) for p in param_chunks - ] + return self._fetch_data(url, param_chunks) def get_historical_forecast_json_list( self, @@ -534,11 +607,8 @@ def get_historical_forecast_json_list( Returns: List[Dict[str, Any]]: A list of JSON responses for each requested date. """ - if not self._is_token_valid(): - self._login() url = f"{self.url_base}/v3/forecast/historical" - headers = {"Authorization": f"Bearer {self.token}"} params = { "region": region, "signal_type": signal_type, @@ -552,18 +622,16 @@ def get_historical_forecast_json_list( # add timezone to dates { **params, - "start": datetime.combine(d, time(0, 0)).isoformat() + "Z", - "end": datetime.combine(d, time(23, 59)).isoformat() + "Z", + "start": datetime.combine(d, dt_time(0, 0)).isoformat() + "Z", + "end": datetime.combine(d, dt_time(23, 59)).isoformat() + "Z", } for d in list_of_dates ] if self.multithreaded: - return self._fetch_data_multithreaded(url, headers, param_chunks) + return self._fetch_data_multithreaded(url, param_chunks) else: - return [ - self._make_rate_limited_request(url, headers, p) for p in param_chunks - ] + return self._fetch_data(url, param_chunks) def get_historical_forecast_pandas( self, @@ -642,11 +710,7 @@ def get_maps_json( Returns: dict: The JSON response from the API. """ - if not self._is_token_valid(): - self._login() + url = "{}/v3/maps".format(self.url_base) - headers = {"Authorization": "Bearer " + self.token} params = {"signal_type": signal_type} - rsp = requests.get(url, headers=headers, params=params) - rsp.raise_for_status() - return rsp.json() + return self._make_rate_limited_request(url, params) diff --git a/watttime/util.py b/watttime/util.py deleted file mode 100644 index fbcdb3a2..00000000 --- a/watttime/util.py +++ /dev/null @@ -1,88 +0,0 @@ -import threading -import time -import os -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Dict, Any, List -import requests - - -class RateLimitedRequesterMixin: - """ - Mixin to handle rate-limited multi-threaded requests. - """ - - def __init__(self, rate_limit: int = 10): - """ - Args: - rate_limit (int): Maximum number of requests per second. - """ - self._rate_limit_lock = threading.Lock() - self._last_request_times = [] - self.rate_limit = rate_limit - - def _make_rate_limited_request( - self, url: str, headers: Dict[str, str], params: Dict[str, Any] - ) -> Dict: - """ - Makes an API request with rate limiting. - - Args: - url (str): API endpoint URL. - headers (Dict[str, str]): Request headers. - params (Dict[str, Any]): Query parameters. - - Returns: - Dict: The JSON response. - """ - while True: - with self._rate_limit_lock: - current_time = time.time() - self._last_request_times = [ - t for t in self._last_request_times if current_time - t < 1.0 - ] - - if len(self._last_request_times) < self.rate_limit: - self._last_request_times.append(current_time) - break - - time.sleep(0.1) - - rsp = requests.get(url, headers=headers, params=params) - rsp.raise_for_status() - return rsp.json() - - def _fetch_data_multithreaded( - self, url: str, headers: Dict[str, str], param_chunks: List[Dict[str, Any]] - ) -> List[Dict]: - """ - Fetch data using multithreading with rate limiting. - - Args: - url (str): API endpoint URL. - headers (Dict[str, str]): Request headers. - param_chunks (List[Dict[str, Any]]): List of parameter sets. - - Returns: - List[Dict]: A list of JSON responses. - """ - responses = [] - with ThreadPoolExecutor(max_workers=os.cpu_count() * 5) as executor: - futures = { - executor.submit( - self._make_rate_limited_request, url, headers, params - ): params - for params in param_chunks - } - - for future in as_completed(futures): - try: - responses.append(future.result()) - except Exception as e: - if hasattr(e, "response"): - raise Exception( - f"\nAPI Response Error: {e.response.status_code}, {e.response.text} " - f"[{e.response.headers.get('x-request-id')}]" - ) - raise - - return responses From 9217d5c2b05e756bf9e2c0584db1442dd2eb5858 Mon Sep 17 00:00:00 2001 From: samkoebrich Date: Tue, 11 Feb 2025 15:48:08 -0700 Subject: [PATCH 2/8] combine _fetch_data and _fetch_data_multithreaded methods --- watttime/api.py | 58 ++++++++++++++----------------------------------- 1 file changed, 16 insertions(+), 42 deletions(-) diff --git a/watttime/api.py b/watttime/api.py index 43ec8007..cee69579 100644 --- a/watttime/api.py +++ b/watttime/api.py @@ -259,34 +259,19 @@ def _fetch_data( param_chunks = [param_chunks] responses = [] - for params in param_chunks: - rsp = self._make_rate_limited_request(url, params) - responses.append(rsp) - - return responses - - def _fetch_data_multithreaded( - self, url: str, param_chunks: List[Dict[str, Any]] - ) -> List[Dict]: - """ - Fetch data using multithreading with rate limiting. - - Args: - url (str): API endpoint URL. - param_chunks (List[Dict[str, Any]]): List of parameter sets. - - Returns: - List[Dict]: A list of JSON responses. - """ - responses = [] - with ThreadPoolExecutor(max_workers=os.cpu_count() * 5) as executor: - futures = { - executor.submit(self._make_rate_limited_request, url, params): params - for params in param_chunks - } - - for future in as_completed(futures): - responses.append(future.result()) + if self.multithreaded: + with ThreadPoolExecutor(max_workers=os.cpu_count() * 5) as executor: + futures = { + executor.submit(self._make_rate_limited_request, url, params): params + for params in param_chunks + } + + for future in as_completed(futures): + responses.append(future.result()) + else: + for params in param_chunks: + rsp = self._make_rate_limited_request(url, params) + responses.append(rsp) return responses @@ -330,10 +315,7 @@ def get_historical_jsons( params["model"] = model param_chunks = [{**params, "start": c[0], "end": c[1]} for c in chunks] - if self.multithreaded: - responses = self._fetch_data_multithreaded(url, param_chunks) - else: - responses = self._fetch_data(url, param_chunks) + responses = self._fetch_data(url, param_chunks) # the API should not let this happen, but ensure for sanity unique_models = set([r["meta"]["model"]["date"] for r in responses]) @@ -578,11 +560,7 @@ def get_historical_forecast_json( params["model"] = model param_chunks = [{**params, "start": c[0], "end": c[1]} for c in chunks] - - if self.multithreaded: - return self._fetch_data_multithreaded(url, param_chunks) - else: - return self._fetch_data(url, param_chunks) + return self._fetch_data(url, param_chunks) def get_historical_forecast_json_list( self, @@ -627,11 +605,7 @@ def get_historical_forecast_json_list( } for d in list_of_dates ] - - if self.multithreaded: - return self._fetch_data_multithreaded(url, param_chunks) - else: - return self._fetch_data(url, param_chunks) + return self._fetch_data(url, param_chunks) def get_historical_forecast_pandas( self, From 67fe635d1ac0b0a5509a77025ffa3470647bdb34 Mon Sep 17 00:00:00 2001 From: samkoebrich Date: Tue, 11 Feb 2025 15:50:03 -0700 Subject: [PATCH 3/8] remove test for seperate _fetch_data_multithreaded method --- tests/test_sdk.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/tests/test_sdk.py b/tests/test_sdk.py index f08e8127..f8bd74ca 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -69,31 +69,6 @@ def test_apply_rate_limit(self, mock_sleep): # Expect sleep to be called with: wait_time = 1.0 - (ts - earliest_timestamp) = 1.0 - (0.5 - 0) = 0.5 seconds. mock_sleep.assert_called_with(0.5) - @patch.object(WattTimeBase, "_make_rate_limited_request") - def test_fetch_data_multithreaded(self, mock_make_rate_limited_request): - """Test _fetch_data_multithreaded calls _make_rate_limited_request for each request.""" - # For multi-threaded instance, simulate _make_rate_limited_request to return the passed parameters. - mock_make_rate_limited_request.side_effect = lambda url, params: { - "data": params - } - - url = f"{self.base_mt.url_base}/test_endpoint" - param_chunks = [{"param1": i} for i in range(5)] # Simulate 5 requests - - responses = self.base_mt._fetch_data_multithreaded(url, param_chunks) - - self.assertEqual(len(responses), 5) # Ensure all requests are handled - mock_make_rate_limited_request.assert_called() # Ensure it's called at least once - - expected_calls = [({"param1": i}) for i in range(5)] - actual_calls = [ - call.args[1] for call in mock_make_rate_limited_request.call_args_list - ] - - self.assertListEqual( - expected_calls, actual_calls - ) # Ensure correct params are passed - def test_parse_dates_with_string(self): start = "2022-01-01" end = "2022-01-31" From eeb20bb8731286073fff821ae660446e183893f6 Mon Sep 17 00:00:00 2001 From: samkoebrich Date: Wed, 12 Feb 2025 10:36:15 -0700 Subject: [PATCH 4/8] do to expose credentials as persistent attributes --- tests/test_sdk.py | 12 ++++++++++++ watttime/api.py | 35 +++++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/tests/test_sdk.py b/tests/test_sdk.py index e3df9e4c..e5adbb9c 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -125,6 +125,18 @@ def test_mock_register(self, mock_post): resp = self.base.register(email=os.getenv("WATTTIME_EMAIL")) self.assertEqual(len(mock_post.call_args_list), 1) + def test_get_password(self): + + with mock.patch.dict(os.environ, {}, clear=True), self.assertRaises(ValueError): + wt_base = WattTimeBase() + + with mock.patch.dict(os.environ, {}, clear=True): + wt_base = WattTimeBase( + username="WATTTIME_USERNAME", password="WATTTIME_PASSWORD" + ) + self.assertEqual(wt_base.username, "WATTTIME_USERNAME") + self.assertEqual(wt_base.password, "WATTTIME_PASSWORD") + class TestWattTimeHistorical(unittest.TestCase): def setUp(self): diff --git a/watttime/api.py b/watttime/api.py index 4d451dff..bbcba273 100644 --- a/watttime/api.py +++ b/watttime/api.py @@ -26,11 +26,42 @@ def __init__(self, username: Optional[str] = None, password: Optional[str] = Non username (Optional[str]): The username to use for authentication. If not provided, the value will be retrieved from the environment variable "WATTTIME_USER". password (Optional[str]): The password to use for authentication. If not provided, the value will be retrieved from the environment variable "WATTTIME_PASSWORD". """ - self.username = username or os.getenv("WATTTIME_USER") - self.password = password or os.getenv("WATTTIME_PASSWORD") + + # This only applies to the current session, is not stored persistently + if username and not os.getenv("WATTTIME_USER"): + os.environ["WATTTIME_USER"] = username + if password and not os.getenv("WATTTIME_PASSWORD"): + os.environ["WATTTIME_PASSWORD"] = password + + # Accessing attributes will raise exception if variables are not set + _ = self.password + _ = self.username + self.token = None self.token_valid_until = None + @property + def password(self): + password = os.getenv("WATTTIME_PASSWORD") + if not password: + raise ValueError( + "WATTTIME_PASSWORD env variable is not set." + + "Please set this variable, or pass in a password upon initialization," + + "which will store it as a variable only for the current session" + ) + return password + + @property + def username(self): + username = os.getenv("WATTTIME_USER") + if not username: + raise ValueError( + "WATTTIME_USER env variable is not set." + + "Please set this variable, or pass in a username upon initialization," + + "which will store it as a variable only for the current session" + ) + return username + def _login(self): """ Login to the WattTime API, which provides a JWT valid for 30 minutes From 26decacc593eb0d0be3cf2ef32c7b9e110568e9c Mon Sep 17 00:00:00 2001 From: samkoebrich Date: Wed, 12 Feb 2025 16:15:23 -0700 Subject: [PATCH 5/8] remove use of Session for now, will move into seperate PR --- watttime/api.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/watttime/api.py b/watttime/api.py index cee69579..b70da778 100644 --- a/watttime/api.py +++ b/watttime/api.py @@ -47,8 +47,6 @@ def __init__( ) # prevent multiple threads from modifying _last_request_times simultaneously self._rate_limit_condition = threading.Condition(self._rate_limit_lock) - self.session = requests.Session() - def _login(self): """ Login to the WattTime API, which provides a JWT valid for 30 minutes @@ -58,7 +56,7 @@ def _login(self): """ url = f"{self.url_base}/login" - rsp = self.session.get( + rsp = requests.get( url, auth=requests.auth.HTTPBasicAuth(self.username, self.password), timeout=20, @@ -149,7 +147,7 @@ def register(self, email: str, organization: Optional[str] = None) -> None: "org": organization, } - rsp = self.session.post(url, json=params, timeout=20) + rsp = requests.post(url, json=params, timeout=20) rsp.raise_for_status() print( f"Successfully registered {self.username}, please check {email} for a verification email" @@ -204,7 +202,7 @@ def _make_rate_limited_request(self, url: str, params: Dict[str, Any]) -> Dict: self._apply_rate_limit(ts) try: - rsp = self.session.get(url, headers=self.headers, params=params) + rsp = requests.get(url, headers=self.headers, params=params) rsp.raise_for_status() j = rsp.json() except requests.exceptions.RequestException as e: @@ -262,7 +260,9 @@ def _fetch_data( if self.multithreaded: with ThreadPoolExecutor(max_workers=os.cpu_count() * 5) as executor: futures = { - executor.submit(self._make_rate_limited_request, url, params): params + executor.submit( + self._make_rate_limited_request, url, params + ): params for params in param_chunks } From ae09336e7a5394da4aa3cf7ecb226b555265198e Mon Sep 17 00:00:00 2001 From: Joel Cofield Date: Wed, 12 Feb 2025 21:56:38 -0800 Subject: [PATCH 6/8] track _last_request_meta in request method --- watttime/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/watttime/api.py b/watttime/api.py index b70da778..ea8691a4 100644 --- a/watttime/api.py +++ b/watttime/api.py @@ -213,6 +213,8 @@ def _make_rate_limited_request(self, url: str, params: Dict[str, Any]) -> Dict: if j.get("meta", {}).get("warnings"): print("Warnings Returned: %s | Response: %s", params, j["meta"]) + self._last_request_meta = j.get("meta", {}) + return j def _apply_rate_limit(self, ts: float): From 5e11102b8ca2e081c379fbaf34f7a21adc9e2834 Mon Sep 17 00:00:00 2001 From: sam-watttime <75635755+sam-watttime@users.noreply.github.com> Date: Wed, 12 Mar 2025 10:10:44 -0600 Subject: [PATCH 7/8] Test for every region in my-access in maps geojson (#40) --- tests/test_sdk.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_sdk.py b/tests/test_sdk.py index f8bd74ca..c8002259 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -430,6 +430,7 @@ def test_horizon_hours(self): class TestWattTimeMaps(unittest.TestCase): def setUp(self): self.maps = WattTimeMaps() + self.myaccess = WattTimeMyAccess() def test_get_maps_json_moer(self): moer = self.maps.get_maps_json(signal_type="co2_moer") @@ -466,6 +467,22 @@ def test_region_from_loc(self): self.assertEqual(region["region_full_name"], "Public Service Co of Colorado") self.assertEqual(region["signal_type"], "co2_moer") + def test_my_access_in_geojson(self): + access = self.myaccess.get_access_pandas() + for signal_type in ["co2_moer", "co2_aoer", "health_damage"]: + access_regions = access.loc[ + access["signal_type"] == signal_type, "region" + ].unique() + maps = self.maps.get_maps_json(signal_type=signal_type) + maps_regions = [i["properties"]["region"] for i in maps["features"]] + + assert ( + set(access_regions) - set(maps_regions) == set() + ), f"Missing regions in geojson for {signal_type}: {set(access_regions) - set(maps_regions)}" + assert ( + set(maps_regions) - set(access_regions) == set() + ), f"Extra regions in geojson for {signal_type}: {set(maps_regions) - set(access_regions)}" + if __name__ == "__main__": unittest.main() From 425c2e316f3e77f95117b110d42cc5dd15613e7c Mon Sep 17 00:00:00 2001 From: sam-watttime <75635755+sam-watttime@users.noreply.github.com> Date: Wed, 12 Mar 2025 10:11:35 -0600 Subject: [PATCH 8/8] refactor multithreading into WattTimeBase (#35) --- tests/test_sdk.py | 2 +- watttime/api.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_sdk.py b/tests/test_sdk.py index c8002259..d5dfa2a9 100644 --- a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -135,7 +135,7 @@ def test_parse_dates_with_datetime(self): self.assertIsInstance(parsed_end, datetime) self.assertEqual(parsed_end.tzinfo, UTC) - @mock.patch("watttime.requests.Session.post", side_effect=mocked_register) + @mock.patch("requests.post", side_effect=mocked_register) def test_mock_register(self, mock_post): resp = self.base.register(email=os.getenv("WATTTIME_EMAIL")) self.assertEqual(len(mock_post.call_args_list), 1) diff --git a/watttime/api.py b/watttime/api.py index ea8691a4..59bf9965 100644 --- a/watttime/api.py +++ b/watttime/api.py @@ -260,7 +260,7 @@ def _fetch_data( responses = [] if self.multithreaded: - with ThreadPoolExecutor(max_workers=os.cpu_count() * 5) as executor: + with ThreadPoolExecutor(max_workers=10) as executor: futures = { executor.submit( self._make_rate_limited_request, url, params