From 3a5bec9c99319a7cc66b45686933b68a9faf829c Mon Sep 17 00:00:00 2001 From: sixtysxx Date: Wed, 25 Jun 2025 07:07:17 -0600 Subject: [PATCH] fixed some pylance errors and import errors --- pocketoptionapi_async/client.py | 49 ++++++------------- .../connection_keep_alive.py | 34 ++++++++++--- pocketoptionapi_async/connection_monitor.py | 4 +- pocketoptionapi_async/constants.py | 4 +- pocketoptionapi_async/exceptions.py | 4 +- pocketoptionapi_async/monitoring.py | 22 ++++++--- pocketoptionapi_async/websocket_client.py | 38 ++++++++++---- tests/advanced_testing_suite.py | 2 +- tests/integration_tests.py | 6 +-- tests/test_order_placement_fix.py | 3 +- tools/client_test.py | 15 +++--- tools/get_ssid.py | 8 ++- 12 files changed, 113 insertions(+), 76 deletions(-) diff --git a/pocketoptionapi_async/client.py b/pocketoptionapi_async/client.py index ee1cc8d..8c2b335 100644 --- a/pocketoptionapi_async/client.py +++ b/pocketoptionapi_async/client.py @@ -147,7 +147,7 @@ def _setup_event_handlers(self): self._websocket.add_event_handler("disconnected", self._on_disconnected) async def connect( - self, regions: Optional[List[str]] = None, persistent: bool = None + self, regions: Optional[List[str]] = None, persistent: Optional[bool] = None ) -> bool: """ Connect to PocketOption with multiple region support @@ -162,7 +162,7 @@ async def connect( logger.info("Connecting to PocketOption...") # Update persistent setting if provided if persistent is not None: - self.persistent_connection = persistent + self.persistent_connection = bool(persistent) try: if self.persistent_connection: @@ -1181,39 +1181,31 @@ async def _on_candles_received(self, data: Dict[str, Any]) -> None: logger.info(f"🕯️ Candles received with data: {type(data)}") # Check if we have pending candle requests if hasattr(self, "_candle_requests") and self._candle_requests: - # Parse the candles data try: - # Get the first pending request to extract asset and timeframe info for request_id, future in list(self._candle_requests.items()): if not future.done(): - # Extract asset and timeframe from request_id format: "asset_timeframe" parts = request_id.split("_") if len(parts) >= 2: - asset = "_".join( - parts[:-1] - ) # Handle assets with underscores + asset = "_".join(parts[:-1]) timeframe = int(parts[-1]) - - candles = self._parse_candles_data(data, asset, timeframe) + candles = self._parse_candles_data( + data.get("candles", []), asset, timeframe + ) if self.enable_logging: logger.info( f"🕯️ Parsed {len(candles)} candles from response" ) - future.set_result(candles) if self.enable_logging: logger.debug(f"Resolved candle request: {request_id}") break - except Exception as e: if self.enable_logging: logger.error(f"Error processing candles data: {e}") - # Resolve futures with empty result for request_id, future in list(self._candle_requests.items()): if not future.done(): future.set_result([]) break - await self._emit_event("candles_received", data) async def _on_disconnected(self, data: Dict[str, Any]) -> None: @@ -1227,51 +1219,39 @@ async def _handle_candles_stream(self, data: Dict[str, Any]) -> None: try: asset = data.get("asset") period = data.get("period") - if not asset or not period: return - request_id = f"{asset}_{period}" - if self.enable_logging: logger.info(f"🕯️ Processing candle stream for {asset} ({period}s)") - - # Check if we have a pending request for this asset/period if ( hasattr(self, "_candle_requests") and request_id in self._candle_requests ): future = self._candle_requests[request_id] - if not future.done(): - # Parse candles from stream data - candles = self._parse_stream_candles(data) + candles = self._parse_stream_candles(data, asset, period) if candles: future.set_result(candles) if self.enable_logging: logger.info( f"🕯️ Resolved candle request for {asset} with {len(candles)} candles" ) - - # Clean up the request del self._candle_requests[request_id] - except Exception as e: if self.enable_logging: logger.error(f"Error handling candles stream: {e}") - def _parse_stream_candles(self, stream_data: Dict[str, Any]): + def _parse_stream_candles( + self, stream_data: Dict[str, Any], asset: str, timeframe: int + ): """Parse candles from stream update data (changeSymbol response)""" candles = [] - try: - # Stream data might contain candles in different formats candle_data = stream_data.get("data") or stream_data.get("candles") or [] - if isinstance(candle_data, list): for item in candle_data: if isinstance(item, dict): - # Dict format candle = Candle( timestamp=datetime.fromtimestamp(item.get("time", 0)), open=float(item.get("open", 0)), @@ -1279,10 +1259,11 @@ def _parse_stream_candles(self, stream_data: Dict[str, Any]): low=float(item.get("low", 0)), close=float(item.get("close", 0)), volume=float(item.get("volume", 0)), + asset=asset, + timeframe=timeframe, ) candles.append(candle) elif isinstance(item, (list, tuple)) and len(item) >= 6: - # Array format: [timestamp, open, close, high, low, volume] candle = Candle( timestamp=datetime.fromtimestamp(item[0]), open=float(item[1]), @@ -1290,16 +1271,14 @@ def _parse_stream_candles(self, stream_data: Dict[str, Any]): low=float(item[4]), close=float(item[2]), volume=float(item[5]) if len(item) > 5 else 0.0, + asset=asset, + timeframe=timeframe, ) candles.append(candle) - - # Sort by timestamp candles.sort(key=lambda x: x.timestamp) - except Exception as e: if self.enable_logging: logger.error(f"Error parsing stream candles: {e}") - return candles async def _on_keep_alive_connected(self): diff --git a/pocketoptionapi_async/connection_keep_alive.py b/pocketoptionapi_async/connection_keep_alive.py index 5705fd6..23d4b00 100644 --- a/pocketoptionapi_async/connection_keep_alive.py +++ b/pocketoptionapi_async/connection_keep_alive.py @@ -6,8 +6,8 @@ from typing import Optional, List, Callable, Dict, Any from datetime import datetime, timedelta from loguru import logger -import websockets from websockets.exceptions import ConnectionClosed +from websockets.legacy.client import connect, WebSocketClientProtocol from models import ConnectionInfo, ConnectionStatus from constants import REGIONS @@ -23,7 +23,7 @@ def __init__(self, ssid: str, is_demo: bool = True): self.is_demo = is_demo # Connection state - self.websocket: Optional[websockets.WebSocketServerProtocol] = None + self.websocket: Optional[WebSocketClientProtocol] = None self.connection_info: Optional[ConnectionInfo] = None self.is_connected = False self.should_reconnect = True @@ -138,7 +138,7 @@ async def _establish_connection(self) -> bool: # Connect with headers (like old API) self.websocket = await asyncio.wait_for( - websockets.connect( + connect( url, ssl=ssl_context, extra_headers={ @@ -198,6 +198,8 @@ async def _establish_connection(self) -> bool: async def _send_handshake(self): """Send initial handshake sequence (like old API)""" try: + if not self.websocket: + raise RuntimeError("Handshake called with no websocket connection.") # Wait for initial connection message initial_message = await asyncio.wait_for( self.websocket.recv(), timeout=10.0 @@ -408,9 +410,10 @@ async def _process_message(self, message): # Handle ping-pong (like old API) if message == "2": - await self.websocket.send("3") - self.connection_stats["last_pong_time"] = datetime.now() - logger.debug("Ping: Pong sent") + if self.websocket: + await self.websocket.send("3") + self.connection_stats["last_pong_time"] = datetime.now() + logger.debug("Ping: Pong sent") return # Handle authentication success (like old API) @@ -490,6 +493,25 @@ def get_connection_stats(self) -> Dict[str, Any]: "available_regions": len(self.available_urls), } + async def connect_with_keep_alive( + self, regions: Optional[List[str]] = None + ) -> bool: + """Establish a persistent connection with keep-alive, optionally using a list of regions.""" + # Optionally update available_urls if regions are provided + if regions: + # Assume regions are URLs or region names; adapt as needed + self.available_urls = regions + self.current_url_index = 0 + return await self.start_persistent_connection() + + async def disconnect(self) -> None: + """Disconnect and clean up persistent connection.""" + await self.stop_persistent_connection() + + def get_stats(self) -> Dict[str, Any]: + """Return connection statistics (alias for get_connection_stats).""" + return self.get_connection_stats() + async def demo_keep_alive(): """Demo of the keep-alive connection manager""" diff --git a/pocketoptionapi_async/connection_monitor.py b/pocketoptionapi_async/connection_monitor.py index e3e87b6..1aaa64c 100644 --- a/pocketoptionapi_async/connection_monitor.py +++ b/pocketoptionapi_async/connection_monitor.py @@ -602,7 +602,7 @@ def generate_diagnostics_report(self) -> Dict[str, Any]: return report - def export_metrics_csv(self, filename: str = None) -> str: + def export_metrics_csv(self, filename: str = "") -> str: """Export metrics to CSV file""" if not filename: filename = f"metrics_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" @@ -723,7 +723,7 @@ async def _display_loop(self): await asyncio.sleep(1) -async def run_monitoring_demo(ssid: str = None): +async def run_monitoring_demo(ssid: Optional[str] = None): """Run monitoring demonstration""" if not ssid: diff --git a/pocketoptionapi_async/constants.py b/pocketoptionapi_async/constants.py index 8080d9a..c9077cb 100644 --- a/pocketoptionapi_async/constants.py +++ b/pocketoptionapi_async/constants.py @@ -178,8 +178,10 @@ def get_all_regions(cls) -> Dict[str, str]: """Get all regions as a dictionary""" return cls._REGIONS.copy() + from typing import Optional + @classmethod - def get_region(cls, region_name: str) -> str: + def get_region(cls, region_name: str) -> Optional[str]: """Get specific region URL""" return cls._REGIONS.get(region_name.upper()) diff --git a/pocketoptionapi_async/exceptions.py b/pocketoptionapi_async/exceptions.py index f8212ac..d4759ff 100644 --- a/pocketoptionapi_async/exceptions.py +++ b/pocketoptionapi_async/exceptions.py @@ -6,7 +6,9 @@ class PocketOptionError(Exception): """Base exception for all PocketOption API errors""" - def __init__(self, message: str, error_code: str = None): + from typing import Optional + + def __init__(self, message: str, error_code: Optional[str] = None): super().__init__(message) self.message = message self.error_code = error_code diff --git a/pocketoptionapi_async/monitoring.py b/pocketoptionapi_async/monitoring.py index accbb86..2168422 100644 --- a/pocketoptionapi_async/monitoring.py +++ b/pocketoptionapi_async/monitoring.py @@ -63,11 +63,13 @@ class PerformanceMetrics: class CircuitBreaker: """Circuit breaker pattern implementation""" + from typing import Type + def __init__( self, failure_threshold: int = 5, recovery_timeout: int = 60, - expected_exception: type = Exception, + expected_exception: Type[BaseException] = Exception, ): self.failure_threshold = failure_threshold self.recovery_timeout = recovery_timeout @@ -79,7 +81,10 @@ def __init__( async def call(self, func: Callable, *args, **kwargs): """Execute function with circuit breaker protection""" if self.state == "OPEN": - if time.time() - self.last_failure_time < self.recovery_timeout: + if ( + self.last_failure_time is not None + and time.time() - self.last_failure_time < self.recovery_timeout + ): raise Exception("Circuit breaker is OPEN") else: self.state = "HALF_OPEN" @@ -155,7 +160,10 @@ async def execute(self, func: Callable, *args, **kwargs): ) await asyncio.sleep(delay) - raise last_exception + if last_exception is not None: + raise last_exception + else: + raise Exception("RetryPolicy failed but no exception was captured.") class ErrorMonitor: @@ -197,8 +205,8 @@ async def record_error( severity: ErrorSeverity, category: ErrorCategory, message: str, - context: Dict[str, Any] = None, - stack_trace: str = None, + context: Optional[Dict[str, Any]] = None, + stack_trace: Optional[str] = None, ): """Record an error event""" error_event = ErrorEvent( @@ -208,7 +216,7 @@ async def record_error( category=category, message=message, context=context or {}, - stack_trace=stack_trace, + stack_trace=stack_trace or "", ) self.errors.append(error_event) @@ -340,7 +348,7 @@ async def execute_with_monitoring( "args": str(args)[:200], # Truncate for security "kwargs": str({k: str(v)[:100] for k, v in kwargs.items()})[:200], }, - stack_trace=None, # Could add traceback.format_exc() here + stack_trace="", # Could add traceback.format_exc() here ) raise e diff --git a/pocketoptionapi_async/websocket_client.py b/pocketoptionapi_async/websocket_client.py index 7d54d75..094bf89 100644 --- a/pocketoptionapi_async/websocket_client.py +++ b/pocketoptionapi_async/websocket_client.py @@ -11,6 +11,7 @@ from collections import deque import websockets from websockets.exceptions import ConnectionClosed +from websockets.legacy.client import WebSocketClientProtocol from loguru import logger from .models import ConnectionInfo, ConnectionStatus, ServerTime @@ -61,8 +62,7 @@ class ConnectionPool: """Connection pool for better resource management""" def __init__(self, max_connections: int = 3): - self.max_connections = max_connections - self.active_connections: Dict[str, websockets.WebSocketServerProtocol] = {} + self.active_connections: Dict[str, WebSocketClientProtocol] = {} self.connection_stats: Dict[str, Dict[str, Any]] = {} self._pool_lock = asyncio.Lock() @@ -119,7 +119,7 @@ class AsyncWebSocketClient: """ def __init__(self): - self.websocket: Optional[websockets.WebSocketServerProtocol] = None + self.websocket: Optional[WebSocketClientProtocol] = None self.connection_info: Optional[ConnectionInfo] = None self.server_time: Optional[ServerTime] = None self._ping_task: Optional[asyncio.Task] = None @@ -166,7 +166,7 @@ async def connect(self, urls: List[str], ssid: str) -> bool: ssl_context.verify_mode = ssl.CERT_NONE # Connect with timeout - self.websocket = await asyncio.wait_for( + ws = await asyncio.wait_for( websockets.connect( url, ssl=ssl_context, @@ -177,6 +177,7 @@ async def connect(self, urls: List[str], ssid: str) -> bool: ), timeout=10.0, ) + self.websocket = ws # type: ignore # Update connection info region = self._extract_region_from_url(url) self.connection_info = ConnectionInfo( @@ -349,11 +350,19 @@ async def _send_handshake(self, ssid: str) -> None: try: # Wait for initial connection message with "0" and "sid" (like old API) logger.debug("Waiting for initial handshake message...") + if not self.websocket: + raise WebSocketError("WebSocket is not connected during handshake") initial_message = await asyncio.wait_for( self.websocket.recv(), timeout=10.0 ) logger.debug(f"Received initial: {initial_message}") + # Ensure initial_message is a string + if isinstance(initial_message, memoryview): + initial_message = bytes(initial_message).decode("utf-8") + elif isinstance(initial_message, (bytes, bytearray)): + initial_message = initial_message.decode("utf-8") + # Check if it's the expected initial message format if initial_message.startswith("0") and "sid" in initial_message: # Send "40" response (like old API) @@ -366,8 +375,14 @@ async def _send_handshake(self, ssid: str) -> None: ) logger.debug(f"Received connection: {conn_message}") - # Check if it's the expected connection message format - if conn_message.startswith("40") and "sid" in conn_message: + # Ensure conn_message is a string + if isinstance(conn_message, memoryview): + conn_message_str = bytes(conn_message).decode("utf-8") + elif isinstance(conn_message, (bytes, bytearray)): + conn_message_str = conn_message.decode("utf-8") + else: + conn_message_str = conn_message + if conn_message_str.startswith("40") and "sid" in conn_message_str: # Send SSID authentication (like old API) await self.send_message(ssid) logger.debug("Sent SSID authentication") @@ -538,7 +553,7 @@ async def _process_message_optimized(self, message) -> None: if cached_time and time.time() - cached_time < self._cache_ttl: # Use cached processing result - cached_result = self._message_cache.get(message_hash) + cached_result = self._message_cache.get(str(message_hash)) if cached_result: await self._emit_event("cached_message", cached_result) return @@ -553,7 +568,10 @@ async def _process_message_optimized(self, message) -> None: logger.warning(f"Unknown message type: {message[:20]}...") # Cache processing result - self._message_cache[message_hash] = {"processed": True, "type": "unknown"} + self._message_cache[str(message_hash)] = { + "processed": True, + "type": "unknown", + } self._message_cache[f"{message_hash}_time"] = time.time() except Exception as e: @@ -651,7 +669,7 @@ def _extract_region_from_url(self, url: str) -> str: return "DEMO" else: return "UNKNOWN" - except: + except Exception: return "UNKNOWN" @property @@ -660,6 +678,6 @@ def is_connected(self) -> bool: return ( self.websocket is not None and not self.websocket.closed - and self.connection_info + and self.connection_info is not None and self.connection_info.status == ConnectionStatus.CONNECTED ) diff --git a/tests/advanced_testing_suite.py b/tests/advanced_testing_suite.py index 58e8782..7ce11e6 100644 --- a/tests/advanced_testing_suite.py +++ b/tests/advanced_testing_suite.py @@ -13,7 +13,7 @@ from pocketoptionapi_async.client import AsyncPocketOptionClient from pocketoptionapi_async.models import OrderDirection, TimeFrame -from connection_keep_alive import ConnectionKeepAlive +from pocketoptionapi_async.connection_keep_alive import ConnectionKeepAlive class AdvancedTestSuite: diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 2ca0147..942bf97 100644 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -12,9 +12,9 @@ from pocketoptionapi_async.client import AsyncPocketOptionClient from pocketoptionapi_async.models import TimeFrame -from connection_keep_alive import ConnectionKeepAlive -from connection_monitor import ConnectionMonitor -from load_testing_tool import LoadTester, LoadTestConfig +from pocketoptionapi_async.connection_keep_alive import ConnectionKeepAlive +from pocketoptionapi_async.connection_monitor import ConnectionMonitor +from performance.load_testing_tool import LoadTester, LoadTestConfig class IntegrationTester: diff --git a/tests/test_order_placement_fix.py b/tests/test_order_placement_fix.py index 6d30569..fee41f1 100644 --- a/tests/test_order_placement_fix.py +++ b/tests/test_order_placement_fix.py @@ -5,6 +5,7 @@ import asyncio import os from loguru import logger +from pocketoptionapi_async import AsyncPocketOptionClient, OrderDirection # Configure logger logger.remove() @@ -14,8 +15,6 @@ level="INFO", ) -from pocketoptionapi_async import AsyncPocketOptionClient, OrderDirection - async def test_order_placement_fix(): """Test the order placement fix""" diff --git a/tools/client_test.py b/tools/client_test.py index 213c94a..39ff86a 100644 --- a/tools/client_test.py +++ b/tools/client_test.py @@ -1,19 +1,21 @@ import websockets import anyio from rich.pretty import pprint as print -from pocketoptionapi.constants import REGION +from pocketoptionapi_async.constants import REGIONS SESSION = r'42["auth",{"session":"a:4:{s:10:\"session_id\";s:32:\"a1dc009a7f1f0c8267d940d0a036156f\";s:10:\"ip_address\";s:12:\"190.162.4.33\";s:10:\"user_agent\";s:120:\"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 OP\";s:13:\"last_activity\";i:1709914958;}793884e7bccc89ec798c06ef1279fcf2","isDemo":0,"uid":27658142,"platform":1}]' async def websocket_client(url, pro): - for i in REGION.get_regions(REGION): + # Use REGIONS.get_all() to get a list of region URLs + region_urls = REGIONS.get_all() + for i in region_urls: print(f"Trying {i}...") try: async with websockets.connect( i, extra_headers={ - "Origin": "https://pocketoption.com/" # main URL + "Origin": "https://pocketoption.com/" # main URL }, ) as websocket: async for message in websocket: @@ -23,14 +25,13 @@ async def websocket_client(url, pro): except Exception as e: print(e) print("Connection lost... reconnecting") - # await anyio.sleep(5) return True async def pro(message, websocket, url): - # if byte data - if type(message) == type(b""): - # cut 100 first symbols of byte date to prevent spam + # Use isinstance for type checking + if isinstance(message, bytes): + # cut 100 first symbols of byte data to prevent spam print(str(message)[:100]) return else: diff --git a/tools/get_ssid.py b/tools/get_ssid.py index 3403302..cf0efb5 100644 --- a/tools/get_ssid.py +++ b/tools/get_ssid.py @@ -3,6 +3,7 @@ import time import re import logging +from typing import cast, List, Dict, Any from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC from driver import get_driver @@ -94,7 +95,12 @@ def get_pocketoption_ssid(): # Retrieve performance logs which include network requests and WebSocket frames. # These logs are crucial for capturing the raw WebSocket messages. - performance_logs = driver.get_log("performance") + get_log = getattr(driver, "get_log", None) + if not callable(get_log): + raise AttributeError( + "Your WebDriver does not support get_log(). Make sure you are using Chrome with performance logging enabled." + ) + performance_logs = cast(List[Dict[str, Any]], get_log("performance")) logger.info(f"Collected {len(performance_logs)} performance log entries.") found_full_ssid_string = None