diff --git a/tests/test_transport/test_validate/test_uattributesvalidator.py b/tests/test_transport/test_validate/test_uattributesvalidator.py index 4ada52d..672e3a5 100644 --- a/tests/test_transport/test_validate/test_uattributesvalidator.py +++ b/tests/test_transport/test_validate/test_uattributesvalidator.py @@ -466,6 +466,39 @@ def test_validating_request_containing_token(self): status = validator.validate(attributes) self.assertEqual(ValidationResult.success(), status) + def test_valid_request_methoduri_in_sink(self): + sink = LongUriSerializer().deserialize("/test.service/1/rpc.method") + attributes = UAttributesBuilder.request(UPriority.UPRIORITY_CS0, sink, 3000).build() + validator = UAttributesValidator.get_validator(attributes) + self.assertEqual("UAttributesValidator.Request", str(validator)) + status = validator.validate(attributes) + self.assertEqual(ValidationResult.success(), status) + + def test_invalid_request_methoduri_in_sink(self): + sink = LongUriSerializer().deserialize("/test.client/1/test.response") + attributes = UAttributesBuilder.request(UPriority.UPRIORITY_CS0, sink, 3000).build() + validator = UAttributesValidator.get_validator(attributes) + self.assertEqual("UAttributesValidator.Request", str(validator)) + status = validator.validate(attributes) + self.assertEqual("Invalid RPC method uri. Uri should be the method to be called, or method from response.", status.get_message()) + + def test_valid_response_uri_in_sink(self): + sink = LongUriSerializer().deserialize("/test.client/1/rpc.response") + attributes = UAttributesBuilder.response(UPriority.UPRIORITY_CS0, sink, Factories.UPROTOCOL.create()).build() + validator = UAttributesValidator.get_validator(attributes) + self.assertEqual("UAttributesValidator.Response", str(validator)) + status = validator.validate(attributes) + self.assertEqual(ValidationResult.success(), status) + + def test_invalid_response_uri_in_sink(self): + sink = LongUriSerializer().deserialize("/test.client/1/rpc.method") + attributes = UAttributesBuilder.response(UPriority.UPRIORITY_CS0, sink, Factories.UPROTOCOL.create()).build() + validator = UAttributesValidator.get_validator(attributes) + self.assertEqual("UAttributesValidator.Response", str(validator)) + status = validator.validate(attributes) + self.assertEqual("Invalid RPC response type.", status.get_message()) + + if __name__ == '__main__': unittest.main() diff --git a/uprotocol/transport/validate/uattributesvalidator.py b/uprotocol/transport/validate/uattributesvalidator.py index 3017fb0..a92ffef 100644 --- a/uprotocol/transport/validate/uattributesvalidator.py +++ b/uprotocol/transport/validate/uattributesvalidator.py @@ -26,9 +26,8 @@ from abc import abstractmethod -from datetime import datetime from enum import Enum - +from uprotocol.proto.uri_pb2 import UUri from uprotocol.proto.uattributes_pb2 import UAttributes, UMessageType from uprotocol.proto.ustatus_pb2 import UCode from uprotocol.uri.validator.urivalidator import UriValidator @@ -221,7 +220,7 @@ def validate_sink(self, attributes_value: UAttributes) -> ValidationResult: @param attributes_value:UAttributes object containing the sink to validate. @return:Returns a ValidationResult that is success or failed with a failure message. """ - return UriValidator.validate_rpc_response( + return UriValidator.validate_rpc_method( attributes_value.sink) if attributes_value.HasField('sink') else ValidationResult.failure("Missing Sink") def validate_ttl(self, attributes_value: UAttributes) -> ValidationResult: @@ -263,11 +262,11 @@ def validate_sink(self, attributes_value: UAttributes) -> ValidationResult: @param attributes_value:UAttributes object containing the sink to validate. @return:Returns a ValidationResult that is success or failed with a failure message. """ - result = UriValidator.validate_rpc_method(attributes_value.sink) - if result.is_success(): - return result - else: + if not attributes_value.HasField('sink') or attributes_value.sink == UUri(): return ValidationResult.failure("Missing Sink") + result = UriValidator.validate_rpc_response(attributes_value.sink) + return result + def validate_req_id(self, attributes_value: UAttributes) -> ValidationResult: """