Skip to content

Commit 01a1c21

Browse files
authored
Change execute_service to be async to handle timeouts internally (#1443)
1 parent 71362b5 commit 01a1c21

File tree

2 files changed

+90
-101
lines changed

2 files changed

+90
-101
lines changed

aioesphomeapi/client.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@
169169

170170
DEFAULT_BLE_TIMEOUT = 30.0
171171
DEFAULT_BLE_DISCONNECT_TIMEOUT = 20.0
172+
DEFAULT_EXECUTE_SERVICE_TIMEOUT = 30.0
172173

173174
SUBSCRIBE_STATES_MSG_TYPES = (*SUBSCRIBE_STATES_RESPONSE_TYPES, CameraImageResponse)
174175

@@ -1313,17 +1314,17 @@ def update_command(
13131314
UpdateCommandRequest(key=key, command=command, device_id=device_id)
13141315
)
13151316

1316-
def execute_service(
1317+
async def execute_service(
13171318
self,
13181319
service: UserService,
13191320
data: ExecuteServiceDataType,
13201321
*,
1321-
on_response: Callable[[ExecuteServiceResponseModel], None] | None = None,
1322-
return_response: bool = False,
1323-
) -> None:
1322+
return_response: bool | None = None,
1323+
timeout: float = DEFAULT_EXECUTE_SERVICE_TIMEOUT,
1324+
) -> ExecuteServiceResponseModel | None:
13241325
connection = self._get_connection()
13251326
# Generate call_id when response callback is provided
1326-
call_id = next(self._call_id_counter) if on_response is not None else 0
1327+
call_id = next(self._call_id_counter) if return_response is not None else 0
13271328
req = ExecuteServiceRequest(
13281329
key=service.key,
13291330
call_id=call_id,
@@ -1353,21 +1354,31 @@ def execute_service(
13531354
req.args.extend(args)
13541355

13551356
# Register callback for response if provided
1356-
if on_response is not None:
1357-
unsub: Callable[[], None] | None = None
1357+
if return_response is not None:
1358+
response_event = asyncio.Event()
1359+
response_msg: ExecuteServiceResponseModel | None = None
13581360

13591361
def _on_response(msg: ExecuteServiceResponse) -> None:
1362+
nonlocal response_msg
13601363
if msg.call_id == call_id:
1361-
on_response(ExecuteServiceResponseModel.from_pb(msg))
1362-
if unsub is not None:
1363-
unsub()
1364+
response_msg = ExecuteServiceResponseModel.from_pb(msg)
1365+
response_event.set()
13641366

13651367
unsub = connection.add_message_callback(
13661368
_on_response,
13671369
(ExecuteServiceResponse,),
13681370
)
13691371

1370-
connection.send_message(req)
1372+
try:
1373+
connection.send_message(req)
1374+
await asyncio.wait_for(response_event.wait(), timeout=timeout)
1375+
return response_msg
1376+
finally:
1377+
unsub()
1378+
else:
1379+
connection.send_message(req)
1380+
1381+
return None
13711382

13721383
def _request_image(self, *, single: bool = False, stream: bool = False) -> None:
13731384
self._get_connection().send_message(

tests/test_client.py

Lines changed: 68 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@
122122
ClimateSwingMode,
123123
DeviceInfo,
124124
ESPHomeBluetoothGATTServices,
125-
ExecuteServiceResponse as ExecuteServiceResponseModel,
126125
FanDirection,
127126
FanSpeed,
128127
HomeassistantActionResponse as HomeassistantActionResponseModel,
@@ -963,9 +962,9 @@ async def test_execute_service(auth_client: APIClient) -> None:
963962
)
964963

965964
with pytest.raises(KeyError):
966-
auth_client.execute_service(service, data={})
965+
await auth_client.execute_service(service, data={})
967966

968-
auth_client.execute_service(
967+
await auth_client.execute_service(
969968
service,
970969
data={
971970
"arg1": False,
@@ -1006,7 +1005,7 @@ async def test_execute_service(auth_client: APIClient) -> None:
10061005
)
10071006

10081007
# Test legacy_int
1009-
auth_client.execute_service(
1008+
await auth_client.execute_service(
10101009
service,
10111010
data={
10121011
"arg1": False,
@@ -1025,7 +1024,7 @@ async def test_execute_service(auth_client: APIClient) -> None:
10251024
send.reset_mock()
10261025

10271026
# Test arg order
1028-
auth_client.execute_service(
1027+
await auth_client.execute_service(
10291028
service,
10301029
data={
10311030
"arg2": 42,
@@ -1045,7 +1044,7 @@ async def test_execute_service(auth_client: APIClient) -> None:
10451044

10461045

10471046
async def test_execute_service_with_call_id(auth_client: APIClient) -> None:
1048-
"""Test that call_id is auto-generated when on_response is provided."""
1047+
"""Test that call_id is auto-generated when return_response is set."""
10491048
send = patch_send(auth_client)
10501049
patch_api_version(auth_client, APIVersion(1, 3))
10511050

@@ -1057,51 +1056,45 @@ async def test_execute_service_with_call_id(auth_client: APIClient) -> None:
10571056
],
10581057
)
10591058

1060-
def dummy_callback(response: ExecuteServiceResponseModel) -> None:
1061-
pass
1062-
1063-
# Test without on_response - call_id should be 0
1064-
auth_client.execute_service(
1059+
# Test without return_response - call_id should be 0
1060+
await auth_client.execute_service(
10651061
service,
10661062
data={"arg1": True},
10671063
)
1068-
send.assert_called_once_with(
1069-
ExecuteServiceRequest(
1070-
key=1,
1071-
args=[
1072-
ExecuteServiceArgument(bool_=True),
1073-
],
1074-
call_id=0,
1075-
return_response=False,
1076-
)
1077-
)
1064+
req = send.call_args[0][0]
1065+
assert req.call_id == 0
10781066
send.reset_mock()
10791067

1080-
# Test with on_response - call_id should be auto-generated (non-zero)
1081-
auth_client.execute_service(
1082-
service,
1083-
data={"arg1": False},
1084-
on_response=dummy_callback,
1085-
)
1068+
# Test with return_response=True - call_id should be auto-generated (non-zero)
1069+
# Use short timeout since no response will come, we just want to verify the request
1070+
with pytest.raises(asyncio.TimeoutError):
1071+
await auth_client.execute_service(
1072+
service,
1073+
data={"arg1": False},
1074+
return_response=True,
1075+
timeout=0.01,
1076+
)
10861077
req = send.call_args[0][0]
10871078
assert req.call_id != 0 # Auto-generated
10881079
first_call_id = req.call_id
10891080
send.reset_mock()
10901081

10911082
# Test that call_id increments
1092-
auth_client.execute_service(
1093-
service,
1094-
data={"arg1": True},
1095-
on_response=dummy_callback,
1096-
)
1083+
with pytest.raises(asyncio.TimeoutError):
1084+
await auth_client.execute_service(
1085+
service,
1086+
data={"arg1": True},
1087+
return_response=True,
1088+
timeout=0.01,
1089+
)
10971090
req = send.call_args[0][0]
10981091
assert req.call_id == first_call_id + 1
10991092

11001093

11011094
async def test_execute_service_return_response_combinations(
11021095
auth_client: APIClient,
11031096
) -> None:
1104-
"""Test that return_response is passed through and call_id is generated for on_response."""
1097+
"""Test return_response behavior and call_id generation."""
11051098
send = patch_send(auth_client)
11061099
patch_api_version(auth_client, APIVersion(1, 3))
11071100

@@ -1111,32 +1104,26 @@ async def test_execute_service_return_response_combinations(
11111104
args=[],
11121105
)
11131106

1114-
def dummy_callback(response: ExecuteServiceResponseModel) -> None:
1115-
pass
1116-
1117-
# Case 1: no callback, no return_response -> call_id=0, return_response=False
1118-
auth_client.execute_service(service, data={})
1119-
assert send.call_args[0][0].return_response is False
1107+
# Case 1: return_response=None (default) -> call_id=0, no waiting
1108+
await auth_client.execute_service(service, data={})
11201109
assert send.call_args[0][0].call_id == 0
11211110
send.reset_mock()
11221111

1123-
# Case 2: return_response=True without callback -> call_id=0, return_response=True
1124-
auth_client.execute_service(service, data={}, return_response=True)
1112+
# Case 2: return_response=True -> generates call_id, waits for response
1113+
with pytest.raises(asyncio.TimeoutError):
1114+
await auth_client.execute_service(
1115+
service, data={}, return_response=True, timeout=0.01
1116+
)
11251117
assert send.call_args[0][0].return_response is True
1126-
assert send.call_args[0][0].call_id == 0
1127-
send.reset_mock()
1128-
1129-
# Case 3: on_response generates call_id, return_response passed through
1130-
auth_client.execute_service(service, data={}, on_response=dummy_callback)
1131-
assert send.call_args[0][0].return_response is False
11321118
assert send.call_args[0][0].call_id != 0
11331119
send.reset_mock()
11341120

1135-
# Case 4: on_response with return_response=True
1136-
auth_client.execute_service(
1137-
service, data={}, on_response=dummy_callback, return_response=True
1138-
)
1139-
assert send.call_args[0][0].return_response is True
1121+
# Case 3: return_response=False -> generates call_id, waits for response
1122+
with pytest.raises(asyncio.TimeoutError):
1123+
await auth_client.execute_service(
1124+
service, data={}, return_response=False, timeout=0.01
1125+
)
1126+
assert send.call_args[0][0].return_response is False
11401127
assert send.call_args[0][0].call_id != 0
11411128
send.reset_mock()
11421129

@@ -2249,15 +2236,14 @@ def on_zwave_proxy_request(msg: ZWaveProxyRequest) -> None:
22492236
assert first_msg.data == b"\x00\x01\x02\x03"
22502237

22512238

2252-
async def test_execute_service_with_response_callback(
2239+
async def test_execute_service_with_response(
22532240
api_client: tuple[
22542241
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
22552242
],
22562243
) -> None:
2257-
"""Test execute_service with on_response callback."""
2244+
"""Test execute_service with return_response returns response directly."""
22582245
client, connection, _transport, protocol = api_client
22592246
patch_api_version(client, APIVersion(1, 3))
2260-
test_msg: list[ExecuteServiceResponseModel] = []
22612247
sent_requests: list[ExecuteServiceRequest] = []
22622248

22632249
# Capture sent requests to get auto-generated call_id
@@ -2270,9 +2256,6 @@ def capture_send(msg: Any) -> None:
22702256

22712257
connection.send_message = capture_send
22722258

2273-
def on_response(msg: ExecuteServiceResponseModel) -> None:
2274-
test_msg.append(msg)
2275-
22762259
service = UserService(
22772260
name="my_service",
22782261
key=1,
@@ -2281,13 +2264,15 @@ def on_response(msg: ExecuteServiceResponseModel) -> None:
22812264
],
22822265
)
22832266

2284-
# Execute service with callback
2285-
client.execute_service(
2286-
service,
2287-
data={"arg1": True},
2288-
on_response=on_response,
2267+
# Execute service with return_response - start as task so we can simulate response
2268+
task = asyncio.create_task(
2269+
client.execute_service(
2270+
service,
2271+
data={"arg1": True},
2272+
return_response=True,
2273+
)
22892274
)
2290-
await asyncio.sleep(0)
2275+
await asyncio.sleep(0) # Let task start and send request
22912276

22922277
# Get the auto-generated call_id
22932278
assert len(sent_requests) == 1
@@ -2302,31 +2287,22 @@ def on_response(msg: ExecuteServiceResponseModel) -> None:
23022287
response_data=b'{"result": "ok"}',
23032288
)
23042289
mock_data_received(protocol, generate_plaintext_packet(response))
2290+
result = await task # Task should complete now that response was received
23052291

2306-
assert len(test_msg) == 1
2307-
first_msg = test_msg[0]
2308-
assert first_msg.call_id == first_call_id
2309-
assert first_msg.success is True
2310-
assert first_msg.error_message == ""
2311-
assert first_msg.response_data == b'{"result": "ok"}'
2292+
assert result is not None
2293+
assert result.call_id == first_call_id
2294+
assert result.success is True
2295+
assert result.error_message == ""
2296+
assert result.response_data == b'{"result": "ok"}'
23122297

2313-
# Callback should auto-unsubscribe, so another response shouldn't be received
2314-
response2: message.Message = ExecuteServiceResponsePb(
2315-
call_id=first_call_id,
2316-
success=True,
2317-
error_message="",
2318-
response_data=b'{"result": "second"}',
2319-
)
2320-
mock_data_received(protocol, generate_plaintext_packet(response2))
2321-
assert len(test_msg) == 1 # Still only one message
2322-
2323-
# Test that responses with different call_id are ignored
2324-
test_msg.clear()
2298+
# Test that responses with different call_id are ignored until correct one arrives
23252299
sent_requests.clear()
2326-
client.execute_service(
2327-
service,
2328-
data={"arg1": False},
2329-
on_response=on_response,
2300+
task2 = asyncio.create_task(
2301+
client.execute_service(
2302+
service,
2303+
data={"arg1": False},
2304+
return_response=True,
2305+
)
23302306
)
23312307
await asyncio.sleep(0)
23322308

@@ -2343,7 +2319,8 @@ def on_response(msg: ExecuteServiceResponseModel) -> None:
23432319
response_data=b"",
23442320
)
23452321
mock_data_received(protocol, generate_plaintext_packet(wrong_response))
2346-
assert len(test_msg) == 0
2322+
await asyncio.sleep(0)
2323+
assert not task2.done() # Task still waiting
23472324

23482325
# Correct call_id should be received
23492326
correct_response: message.Message = ExecuteServiceResponsePb(
@@ -2353,8 +2330,9 @@ def on_response(msg: ExecuteServiceResponseModel) -> None:
23532330
response_data=b"",
23542331
)
23552332
mock_data_received(protocol, generate_plaintext_packet(correct_response))
2356-
assert len(test_msg) == 1
2357-
assert test_msg[0].call_id == second_call_id
2333+
result2 = await task2 # Task should complete now
2334+
assert result2 is not None
2335+
assert result2.call_id == second_call_id
23582336

23592337

23602338
async def test_subscribe_service_calls(auth_client: APIClient) -> None:
@@ -3368,7 +3346,7 @@ async def test_calls_after_connection_closed(
33683346
args=[],
33693347
)
33703348
with pytest.raises(APIConnectionError):
3371-
client.execute_service(service, {})
3349+
await client.execute_service(service, {})
33723350
for method in (
33733351
client.button_command,
33743352
client.climate_command,

0 commit comments

Comments
 (0)