diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index b57da9b8..60272367 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -18,7 +18,7 @@ def new_task(request: Message) -> Task: Raises: TypeError: If the message role is None. - ValueError: If the message parts are empty or if any part has empty content. + ValueError: If the message parts are empty, if any part has empty content, or if the provided context_id is invalid. """ if not request.role: raise TypeError('Message role cannot be None') @@ -28,12 +28,22 @@ def new_task(request: Message) -> Task: if isinstance(part.root, TextPart) and not part.root.text: raise ValueError('TextPart content cannot be empty') + context_id_str = request.context_id + if context_id_str is not None: + try: + uuid.UUID(context_id_str) + context_id = context_id_str + except (ValueError, AttributeError, TypeError) as e: + raise ValueError( + f"Invalid context_id: '{context_id_str}' is not a valid UUID." + ) from e + else: + context_id = str(uuid.uuid4()) + return Task( status=TaskStatus(state=TaskState.submitted), id=(request.task_id if request.task_id else str(uuid.uuid4())), - context_id=( - request.context_id if request.context_id else str(uuid.uuid4()) - ), + context_id=context_id, history=[request], ) diff --git a/tests/utils/test_task.py b/tests/utils/test_task.py index cb3dc386..77441316 100644 --- a/tests/utils/test_task.py +++ b/tests/utils/test_task.py @@ -188,6 +188,24 @@ def test_completed_task_invalid_artifact_type(self): history=[], ) + def test_new_task_with_invalid_context_id(self): + """Test that new_task raises a ValueError for various invalid context_id formats.""" + invalid_ids = ['not-a-uuid', ''] + for invalid_id in invalid_ids: + with self.subTest(invalid_id=invalid_id): + with pytest.raises( + ValueError, + match=f"Invalid context_id: '{invalid_id}' is not a valid UUID.", + ): + new_task( + Message( + role=Role.user, + parts=[Part(root=TextPart(text='test message'))], + message_id=str(uuid.uuid4()), + context_id=invalid_id, + ) + ) + if __name__ == '__main__': unittest.main()