Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 52 additions & 32 deletions tests/test_communication/mock_utransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
SPDX-License-Identifier: Apache-2.0
"""

import asyncio
import threading
from abc import ABC
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -35,6 +36,8 @@
from uprotocol.validation.validationresult import ValidationResult


# ToDo Change the implementation of transport APIs to use the URI match pattern and save listeners
# against the source and sink filter tuple.
class MockUTransport(UTransport):
def get_source(self) -> UUri:
return self.source
Expand All @@ -53,7 +56,7 @@ def build_response(self, request: UMessage):
def close(self):
self.listeners.clear()

def register_listener(self, source_filter: UUri, listener: UListener, sink_filter: UUri = None) -> UStatus:
async def register_listener(self, source_filter: UUri, listener: UListener, sink_filter: UUri = None) -> UStatus:
with self.lock:
if sink_filter is not None: # method uri
topic = UriSerializer().serialize(sink_filter)
Expand All @@ -65,7 +68,7 @@ def register_listener(self, source_filter: UUri, listener: UListener, sink_filte
self.listeners[topic].append(listener)
return UStatus(code=UCode.OK)

def unregister_listener(self, source: UUri, listener: UListener, sink: UUri = None) -> UStatus:
async def unregister_listener(self, source: UUri, listener: UListener, sink: UUri = None) -> UStatus:
with self.lock:
if sink is not None: # method uri
topic = UriSerializer().serialize(sink)
Expand All @@ -82,57 +85,74 @@ def unregister_listener(self, source: UUri, listener: UListener, sink: UUri = No
result = UStatus(code=code)
return result

def send(self, message: UMessage) -> UStatus:
async def send(self, message: UMessage) -> UStatus:
validator = UAttributesValidator.get_validator(message.attributes)
with self.lock:
if message is None or validator.validate(message.attributes) != ValidationResult.success():
return UStatus(code=UCode.INVALID_ARGUMENT, message="Invalid message attributes")

executor = ThreadPoolExecutor(max_workers=5)
executor.submit(self._notify_listeners, message)
# Use a ThreadPoolExecutor with max_workers=1
executor = ThreadPoolExecutor(max_workers=1)

try:
# Submit _notify_listeners to the executor
future = executor.submit(self._notify_listeners, message)

# Await completion of the Future
await asyncio.wrap_future(future)

finally:
# Clean up the executor
executor.shutdown()

return UStatus(code=UCode.OK)

def _notify_listeners(self, umsg):
if umsg.attributes.type == UMessageType.UMESSAGE_TYPE_PUBLISH:
for key, listeners in self.listeners.items():
uri = UriSerializer().deserialize(key)
if not (UriValidator.is_rpc_method(uri) or UriValidator.is_rpc_response(uri)):
for listener in listeners:
listener.on_receive(umsg)

else:
if umsg.attributes.type == UMessageType.UMESSAGE_TYPE_REQUEST:
serialized_uri = UriSerializer().serialize(umsg.attributes.sink)
if serialized_uri not in self.listeners:
# no listener registered for method uri, send dummy response.
# This case will only come for request type
# as for response type, there will always be response handler as it is in up client
serialized_uri = UriSerializer().serialize(UriFactory.ANY)
umsg = self.build_response(umsg)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
if umsg.attributes.type == UMessageType.UMESSAGE_TYPE_PUBLISH:
for key, listeners in self.listeners.items():
uri = UriSerializer().deserialize(key)
if not (UriValidator.is_rpc_method(uri) or UriValidator.is_rpc_response(uri)):
for listener in listeners:
loop.call_soon_threadsafe(listener.on_receive, umsg)

else:
# this is for response type,handle response
serialized_uri = UriSerializer().serialize(UriFactory.ANY)
if umsg.attributes.type == UMessageType.UMESSAGE_TYPE_REQUEST:
serialized_uri = UriSerializer().serialize(umsg.attributes.sink)
if serialized_uri not in self.listeners:
# no listener registered for method uri, send dummy response.
# This case will only come for request type
# as for response type, there will always be response handler as it is in up client
serialized_uri = UriSerializer().serialize(UriFactory.ANY)
umsg = self.build_response(umsg)
else:
# this is for response type,handle response
serialized_uri = UriSerializer().serialize(UriFactory.ANY)

if serialized_uri in self.listeners:
for listener in self.listeners[serialized_uri]:
listener.on_receive(umsg)
break # as there will be only one listener for method uri
if serialized_uri in self.listeners:
for listener in self.listeners[serialized_uri]:
loop.call_soon_threadsafe(listener.on_receive, umsg)
break # as there will be only one listener for method uri
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()


class TimeoutUTransport(MockUTransport, ABC):
def send(self, message):
async def send(self, message):
return UStatus(code=UCode.OK)


class ErrorUTransport(MockUTransport, ABC):
def send(self, message):
async def send(self, message):
return UStatus(code=UCode.FAILED_PRECONDITION)

def register_listener(self, source_filter: UUri, listener: UListener, sink_filter: UUri = None) -> UStatus:
async def register_listener(self, source_filter: UUri, listener: UListener, sink_filter: UUri = None) -> UStatus:
return UStatus(code=UCode.FAILED_PRECONDITION)

def unregister_listener(self, source: UUri, listener: UListener, sink: UUri = None) -> UStatus:
async def unregister_listener(self, source: UUri, listener: UListener, sink: UUri = None) -> UStatus:
return UStatus(code=UCode.FAILED_PRECONDITION)


Expand All @@ -150,7 +170,7 @@ class EchoUTransport(MockUTransport):
def build_response(self, request):
return request

def send(self, message):
async def send(self, message):
response = self.build_response(message)
executor = ThreadPoolExecutor(max_workers=1)
executor.submit(self._notify_listeners, response)
Expand Down
11 changes: 8 additions & 3 deletions tests/test_communication/test_inmemoryrpcclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,21 @@ async def test_invoke_method_with_payload(self):
rpc_client = InMemoryRpcClient(MockUTransport())
response = await rpc_client.invoke_method(self.create_method_uri(), payload, None)
self.assertIsNotNone(response)
self.assertEqual(response, payload)

async def test_invoke_method_with_payload_and_call_options(self):
payload = UPayload.pack_to_any(UUri())
options = CallOptions(2000, UPriority.UPRIORITY_CS5)
rpc_client = InMemoryRpcClient(MockUTransport())
response = await rpc_client.invoke_method(self.create_method_uri(), payload, options)
self.assertIsNotNone(response)
self.assertEqual(response, payload)

async def test_invoke_method_with_null_payload(self):
rpc_client = InMemoryRpcClient(MockUTransport())
response = await rpc_client.invoke_method(self.create_method_uri(), None, CallOptions.DEFAULT)
self.assertIsNotNone(response)
self.assertEqual(response, UPayload.EMPTY)

async def test_invoke_method_with_timeout_transport(self):
payload = UPayload.pack_to_any(UUri())
Expand All @@ -64,7 +67,8 @@ async def test_invoke_method_with_multi_invoke_transport(self):
response2 = await rpc_client.invoke_method(self.create_method_uri(), payload, None)
self.assertIsNotNone(response1)
self.assertIsNotNone(response2)
self.assertEqual(response1, response2)
self.assertEqual(payload, response1)
self.assertEqual(payload, response2)

async def test_close_with_multiple_listeners(self):
rpc_client = InMemoryRpcClient(MockUTransport())
Expand All @@ -74,7 +78,8 @@ async def test_close_with_multiple_listeners(self):
response2 = await rpc_client.invoke_method(self.create_method_uri(), payload, None)
self.assertIsNotNone(response1)
self.assertIsNotNone(response2)
self.assertEqual(response1, response2)
self.assertEqual(payload, response1)
self.assertEqual(payload, response2)
rpc_client.close()

async def test_invoke_method_with_comm_status_transport(self):
Expand All @@ -87,7 +92,7 @@ async def test_invoke_method_with_comm_status_transport(self):

async def test_invoke_method_with_error_transport(self):
class ErrorUTransport(MockUTransport):
def send(self, message):
async def send(self, message):
return UStatus(code=UCode.FAILED_PRECONDITION)

rpc_client = InMemoryRpcClient(ErrorUTransport())
Expand Down
62 changes: 31 additions & 31 deletions tests/test_communication/test_inmemoryrpcserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from uprotocol.v1.ustatus_pb2 import UStatus


class TestInMemoryRpcServer(unittest.TestCase):
class TestInMemoryRpcServer(unittest.IsolatedAsyncioTestCase):
@staticmethod
def create_method_uri():
return UUri(authority_name="Neelam", ue_id=4, ue_version_major=1, resource_id=3)
Expand All @@ -48,64 +48,64 @@ def test_constructor_transport_not_instance(self):
InMemoryRpcServer("Invalid Transport")
self.assertEqual(str(context.exception), UTransport.TRANSPORT_NOT_INSTANCE_ERROR)

def test_register_request_handler_method_uri_none(self):
async def test_register_request_handler_method_uri_none(self):
server = InMemoryRpcServer(MockUTransport())
handler = MagicMock(return_value=UPayload.EMPTY)

with self.assertRaises(ValueError) as context:
server.register_request_handler(None, handler)
await server.register_request_handler(None, handler)
self.assertEqual(str(context.exception), "Method URI missing")

def test_register_request_handler_handler_none(self):
async def test_register_request_handler_handler_none(self):
server = InMemoryRpcServer(MockUTransport())
with self.assertRaises(ValueError) as context:
server.register_request_handler(self.create_method_uri(), None)
await server.register_request_handler(self.create_method_uri(), None)
self.assertEqual(str(context.exception), "Request listener missing")

def test_unregister_request_handler_method_uri_none(self):
async def test_unregister_request_handler_method_uri_none(self):
server = InMemoryRpcServer(MockUTransport())
handler = MagicMock(return_value=UPayload.EMPTY)

with self.assertRaises(ValueError) as context:
server.unregister_request_handler(None, handler)
await server.unregister_request_handler(None, handler)
self.assertEqual(str(context.exception), "Method URI missing")

def test_unregister_request_handler_handler_none(self):
async def test_unregister_request_handler_handler_none(self):
server = InMemoryRpcServer(MockUTransport())
with self.assertRaises(ValueError) as context:
server.unregister_request_handler(self.create_method_uri(), None)
await server.unregister_request_handler(self.create_method_uri(), None)
self.assertEqual(str(context.exception), "Request listener missing")

def test_registering_request_listener(self):
async def test_registering_request_listener(self):
handler = MagicMock(return_value=UPayload.EMPTY)
method = self.create_method_uri()
server = InMemoryRpcServer(MockUTransport())
self.assertEqual(server.register_request_handler(method, handler).code, UCode.OK)
self.assertEqual(server.unregister_request_handler(method, handler).code, UCode.OK)
self.assertEqual((await server.register_request_handler(method, handler)).code, UCode.OK)
self.assertEqual((await server.unregister_request_handler(method, handler)).code, UCode.OK)

def test_registering_twice_the_same_request_handler(self):
async def test_registering_twice_the_same_request_handler(self):
handler = MagicMock(return_value=UPayload.EMPTY)
server = InMemoryRpcServer(MockUTransport())
status = server.register_request_handler(self.create_method_uri(), handler)
status = await server.register_request_handler(self.create_method_uri(), handler)
self.assertEqual(status.code, UCode.OK)
status = server.register_request_handler(self.create_method_uri(), handler)
status = await server.register_request_handler(self.create_method_uri(), handler)
self.assertEqual(status.code, UCode.ALREADY_EXISTS)

def test_unregistering_non_registered_request_handler(self):
async def test_unregistering_non_registered_request_handler(self):
handler = MagicMock(side_effect=NotImplementedError("Unimplemented method 'handleRequest'"))
server = InMemoryRpcServer(MockUTransport())
status = server.unregister_request_handler(self.create_method_uri(), handler)
status = await server.unregister_request_handler(self.create_method_uri(), handler)
self.assertEqual(status.code, UCode.NOT_FOUND)

def test_registering_request_listener_with_error_transport(self):
async def test_registering_request_listener_with_error_transport(self):
handler = MagicMock(return_value=UPayload.EMPTY)
server = InMemoryRpcServer(ErrorUTransport())
status = server.register_request_handler(self.create_method_uri(), handler)
status = await server.register_request_handler(self.create_method_uri(), handler)
self.assertEqual(status.code, UCode.FAILED_PRECONDITION)

def test_handle_requests(self):
async def test_handle_requests(self):
class CustomTestUTransport(MockUTransport):
def send(self, message):
async def send(self, message):
serialized_uri = UriSerializer().serialize(message.attributes.sink)
if serialized_uri in self.listeners:
for listener in self.listeners[serialized_uri]:
Expand All @@ -120,18 +120,18 @@ def send(self, message):
# Update the resource_id
method2.resource_id = 69

self.assertEqual(server.register_request_handler(method, handler).code, UCode.OK)
self.assertEqual((await server.register_request_handler(method, handler)).code, UCode.OK)

request = UMessageBuilder.request(transport.get_source(), method2, 1000).build()

# fake sending a request message that will trigger the handler to be called but since it is
# not for the same method as the one registered, it should be ignored and the handler not called
self.assertEqual(transport.send(request).code, UCode.OK)
self.assertEqual((await transport.send(request)).code, UCode.OK)

def test_handle_requests_exception(self):
async def test_handle_requests_exception(self):
# test transport that will trigger the handleRequest()
class CustomTestUTransport(MockUTransport):
def send(self, message):
async def send(self, message):
serialized_uri = UriSerializer().serialize(message.attributes.sink)
if serialized_uri in self.listeners:
for listener in self.listeners[serialized_uri]:
Expand All @@ -148,14 +148,14 @@ def handle_request(self, message: UMessage) -> UPayload:
server = InMemoryRpcServer(transport)
method = self.create_method_uri()

self.assertEqual(server.register_request_handler(method, handler).code, UCode.OK)
self.assertEqual((await server.register_request_handler(method, handler)).code, UCode.OK)

request = UMessageBuilder.request(transport.get_source(), method, 1000).build()
self.assertEqual(transport.send(request).code, UCode.OK)
self.assertEqual((await transport.send(request)).code, UCode.OK)

def test_handle_requests_unknown_exception(self):
async def test_handle_requests_unknown_exception(self):
class CustomTestUTransport(MockUTransport):
def send(self, message):
async def send(self, message):
serialized_uri = UriSerializer().serialize(message.attributes.sink)
if serialized_uri in self.listeners:
for listener in self.listeners[serialized_uri]:
Expand All @@ -172,10 +172,10 @@ def handle_request(self, message: UMessage) -> UPayload:
server = InMemoryRpcServer(transport)
method = self.create_method_uri()

self.assertEqual(server.register_request_handler(method, handler).code, UCode.OK)
self.assertEqual((await server.register_request_handler(method, handler)).code, UCode.OK)

request = UMessageBuilder.request(transport.get_source(), method, 1000).build()
self.assertEqual(transport.send(request).code, UCode.OK)
self.assertEqual((await transport.send(request)).code, UCode.OK)


if __name__ == '__main__':
Expand Down
10 changes: 5 additions & 5 deletions tests/test_communication/test_inmemorysubscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def test_unregister_listener(self):
subscription_response = await subscriber.subscribe(topic, self.listener, CallOptions())
self.assertFalse(subscription_response is None)

status = subscriber.unregister_listener(topic, self.listener)
status = await subscriber.unregister_listener(topic, self.listener)
self.assertEqual(status.code, UCode.OK)

async def test_unsubscribe_with_commstatus_error(self):
Expand All @@ -91,19 +91,19 @@ async def test_unsubscribe_with_exception(self):
self.assertEqual(response.message, "Request timed out")
self.assertEqual(response.code, UCode.DEADLINE_EXCEEDED)

def test_unregister_listener_missing_topic(self):
async def test_unregister_listener_missing_topic(self):
transport = TimeoutUTransport()
subscriber = InMemorySubscriber(transport, InMemoryRpcClient(transport))
with self.assertRaises(ValueError) as context:
subscriber.unregister_listener(None, self.listener)
await subscriber.unregister_listener(None, self.listener)
self.assertEqual(str(context.exception), "Unsubscribe topic missing")

def test_unregister_listener_missing_listener(self):
async def test_unregister_listener_missing_listener(self):
topic = self.create_topic()
transport = TimeoutUTransport()
subscriber = InMemorySubscriber(transport, InMemoryRpcClient(transport))
with self.assertRaises(ValueError) as context:
subscriber.unregister_listener(topic, None)
await subscriber.unregister_listener(topic, None)
self.assertEqual(str(context.exception), "Request listener missing")

async def test_unsubscribe_missing_topic(self):
Expand Down
Loading