From 597d63ea6c0d87c4746667d893e595c4a642d8e5 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Sat, 14 Jun 2025 16:16:19 -0400 Subject: [PATCH 01/17] Add support for generate protobuf compliant DESCRIPTORs - Fixes https://github.com/danielgtaylor/python-betterproto/issues/443 - Fixes https://github.com/betterproto/python-betterproto2/issues/70 --- .github/workflows/ci.yml | 2 +- .gitignore | 2 +- betterproto2/docs/descriptors.md | 5 + betterproto2/pyproject.toml | 1 + betterproto2/src/betterproto2/__init__.py | 2 +- .../tests/grpc/test_grpclib_reflection.py | 74 +++++++++ .../grpc/test_message_enum_descriptors.py | 17 ++ betterproto2_compiler/pyproject.toml | 14 ++ .../betterproto2_compiler/plugin/models.py | 12 ++ .../betterproto2_compiler/plugin/parser.py | 1 + .../src/betterproto2_compiler/settings.py | 1 + .../templates/header.py.j2 | 8 + .../templates/template.py.j2 | 14 ++ betterproto2_compiler/tests/generate.py | 25 ++- .../grpc_reflection_v1/reflection.proto | 146 ++++++++++++++++++ betterproto2_compiler/tests/util.py | 18 ++- 16 files changed, 331 insertions(+), 11 deletions(-) create mode 100644 betterproto2/docs/descriptors.md create mode 100644 betterproto2/tests/grpc/test_grpclib_reflection.py create mode 100644 betterproto2/tests/grpc/test_message_enum_descriptors.py create mode 100644 betterproto2_compiler/tests/inputs/grpc_reflection_v1/reflection.proto diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9c53af51..62863b54 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,7 +49,7 @@ jobs: - name: Move compiled files to betterproto2 shell: bash - run: mv betterproto2_compiler/tests/output_betterproto betterproto2_compiler/tests/output_betterproto_pydantic betterproto2_compiler/tests/output_reference betterproto2/tests + run: mv betterproto2_compiler/tests/output_betterproto betterproto2_compiler/tests/output_betterproto_pydantic betterproto2_compiler/tests/output_betterproto_descriptor betterproto2_compiler/tests/output_reference betterproto2/tests - name: Execute test suite working-directory: ./betterproto2 diff --git a/.gitignore b/.gitignore index de01ba67..442c3f77 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ .pytest_cache .python-version build/ -tests/output_* +*/tests/output_* **/__pycache__ dist **/*.egg-info diff --git a/betterproto2/docs/descriptors.md b/betterproto2/docs/descriptors.md new file mode 100644 index 00000000..38439300 --- /dev/null +++ b/betterproto2/docs/descriptors.md @@ -0,0 +1,5 @@ +# Google Protobuf Descriptors + +Google's protoc plugin for Python generated DESCRIPTOR fields that enable reflection capabilities in many libraries (e.g. grpc, grpclib, mcap). + +By default, betterproto2 doesn't generate these as it introduces a dependency on `protobuf`. If you're okay with this dependency and want to generate DESCRIPTORs, use the compiler option `python_betterproto2_opt=google_protobuf_descriptors`. diff --git a/betterproto2/pyproject.toml b/betterproto2/pyproject.toml index 41eb24f7..65798b8e 100644 --- a/betterproto2/pyproject.toml +++ b/betterproto2/pyproject.toml @@ -144,6 +144,7 @@ rm -rf tests/output_* && git clone https://github.com/betterproto/python-betterproto2-compiler --branch compiled-test-files --single-branch compiled_files && mv compiled_files/tests_betterproto tests/output_betterproto && mv compiled_files/tests_betterproto_pydantic tests/output_betterproto_pydantic && +mv compiled_files/tests_betterproto_pydantic tests/output_betterproto_descriptor && mv compiled_files/tests_reference tests/output_reference && rm -rf compiled_files """ diff --git a/betterproto2/src/betterproto2/__init__.py b/betterproto2/src/betterproto2/__init__.py index 239d379b..7ceda2f0 100644 --- a/betterproto2/src/betterproto2/__init__.py +++ b/betterproto2/src/betterproto2/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -__all__ = ["__version__", "check_compiler_version", "unwrap", "MessagePool", "validators"] +__all__ = ["__version__", "check_compiler_version", "classproperty", "unwrap", "MessagePool", "validators"] import dataclasses import enum as builtin_enum diff --git a/betterproto2/tests/grpc/test_grpclib_reflection.py b/betterproto2/tests/grpc/test_grpclib_reflection.py new file mode 100644 index 00000000..4ed70bc1 --- /dev/null +++ b/betterproto2/tests/grpc/test_grpclib_reflection.py @@ -0,0 +1,74 @@ +import asyncio +from typing import Generic, TypeVar + +import grpclib +from grpclib.reflection.service import ServerReflection +import pytest +from grpclib.testing import ChannelFor +from google.protobuf import descriptor_pb2 + +from tests.grpc.async_channel import AsyncChannel +from tests.output_betterproto.grpc.reflection.v1 import ErrorResponse, ListServiceResponse, ServiceResponse, ServerReflectionRequest, ServerReflectionStub + +from tests.output_betterproto.example_service import TestBase + +class TestService(TestBase): + pass + +T = TypeVar("T") +class AsyncIterableQueue(Generic[T]): + def __init__(self): + self._queue = asyncio.Queue() + self._done = asyncio.Event() + + def put(self, item: T): + self._queue.put_nowait(item) + + def close(self): + self._queue.shutdown() + + def __aiter__(self): + return self + + async def __anext__(self) -> T: + try: + return await self._queue.get() + except asyncio.QueueShutDown: + raise StopAsyncIteration + +@pytest.mark.asyncio +async def test_grpclib_reflection(): + service = TestService() + services = ServerReflection.extend([service]) + async with ChannelFor(services) as channel: + requests = AsyncIterableQueue[ServerReflectionRequest]() + responses = ServerReflectionStub(channel).server_reflection_info(requests) + + # list services + requests.put(ServerReflectionRequest(list_services="")) + response = await anext(responses) + assert response.list_services_response == ListServiceResponse( + service=[ServiceResponse(name='example_service.Test')]) + + # list methods + + # should fail before we've added descriptors to the protobuf pool + requests.put(ServerReflectionRequest(file_containing_symbol="example_service.Test")) + response = await anext(responses) + assert response.error_response == ErrorResponse(error_code=5, error_message='not found') + assert response.file_descriptor_response is None + + # now it should work + import tests.output_betterproto_descriptor.example_service as example_service_with_desc + requests.put(ServerReflectionRequest(file_containing_symbol="example_service.Test")) + response = await anext(responses) + expected = descriptor_pb2.FileDescriptorProto.FromString(example_service_with_desc.DESCRIPTOR.serialized_pb) + assert response.error_response is None + assert response.file_descriptor_response is not None + assert len(response.file_descriptor_response.file_descriptor_proto) == 1 + actual = descriptor_pb2.FileDescriptorProto.FromString(response.file_descriptor_response.file_descriptor_proto[0]) + assert actual == expected + + requests.close() + + await anext(responses, None) diff --git a/betterproto2/tests/grpc/test_message_enum_descriptors.py b/betterproto2/tests/grpc/test_message_enum_descriptors.py new file mode 100644 index 00000000..100c71b7 --- /dev/null +++ b/betterproto2/tests/grpc/test_message_enum_descriptors.py @@ -0,0 +1,17 @@ +import pytest + +from tests.output_betterproto.service import ThingType, DoThingRequest +from tests.output_betterproto_descriptor.service import ThingType as ThingTypeWithDesc, DoThingRequest as DoThingRequestWithDesc + +def test_message_enum_descriptors(): + # Normally descriptors are not available as they require protobuf support + # to inteoperate with other libraries. + with pytest.raises(AttributeError): + ThingType.DESCRIPTOR.full_name + with pytest.raises(AttributeError): + DoThingRequest.DESCRIPTOR.full_name + + # But the python_betterproto2_opt=google_protobuf_descriptors option + # will add them in as long as protobuf is depended on. + assert ThingTypeWithDesc.DESCRIPTOR.full_name == "service.ThingType" + assert DoThingRequestWithDesc.DESCRIPTOR.full_name == "service.DoThingRequest" diff --git a/betterproto2_compiler/pyproject.toml b/betterproto2_compiler/pyproject.toml index 28bc9a45..ab3804a2 100644 --- a/betterproto2_compiler/pyproject.toml +++ b/betterproto2_compiler/pyproject.toml @@ -122,6 +122,20 @@ python -m grpc.tools.protoc \ google/protobuf/timestamp.proto \ google/protobuf/type.proto \ google/protobuf/wrappers.proto + +python -m grpc.tools.protoc \ + --python_betterproto2_out=tests/output_betterproto_descriptor \ + --python_betterproto2_opt=google_protobuf_descriptors \ + google/protobuf/any.proto \ + google/protobuf/api.proto \ + google/protobuf/duration.proto \ + google/protobuf/empty.proto \ + google/protobuf/field_mask.proto \ + google/protobuf/source_context.proto \ + google/protobuf/struct.proto \ + google/protobuf/timestamp.proto \ + google/protobuf/type.proto \ + google/protobuf/wrappers.proto """ [tool.poe.tasks.typecheck] diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index 52b69d54..923910d4 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -216,6 +216,10 @@ def input_filenames(self) -> list[str]: """ return sorted([f.name for f in self.input_files]) + @property + def descriptor(self): + return self.package_proto_obj.SerializeToString() + @dataclass(kw_only=True) class MessageCompiler(ProtoContentBase): @@ -266,6 +270,10 @@ def custom_methods(self) -> list[str]: return methods_source + @property + def descriptor(self): + return self.proto_obj.SerializeToString() + def is_map(proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto) -> bool: """True if proto_field_obj is a map, otherwise False.""" @@ -595,6 +603,10 @@ def py_name(self) -> str: def deprecated(self) -> bool: return bool(self.proto_obj.options and self.proto_obj.options.deprecated) + @property + def descriptor(self): + return self.proto_obj.SerializeToString() + @dataclass(kw_only=True) class ServiceCompiler(ProtoContentBase): diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py index 72435bb1..5ea41c36 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py @@ -81,6 +81,7 @@ def get_settings(plugin_options: list[str]) -> Settings: return Settings( pydantic_dataclasses="pydantic_dataclasses" in plugin_options, + google_protobuf_descriptors="google_protobuf_descriptors" in plugin_options, client_generation=client_generation, server_generation=server_generation, ) diff --git a/betterproto2_compiler/src/betterproto2_compiler/settings.py b/betterproto2_compiler/src/betterproto2_compiler/settings.py index a5269939..2952a383 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/settings.py +++ b/betterproto2_compiler/src/betterproto2_compiler/settings.py @@ -68,6 +68,7 @@ class ServerGeneration(StrEnum): @dataclass class Settings: pydantic_dataclasses: bool + google_protobuf_descriptors: bool client_generation: ClientGeneration server_generation: ServerGeneration diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 index 64a2de57..2655442e 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 @@ -32,6 +32,9 @@ import betterproto2 from betterproto2.grpc.grpclib_server import ServiceBase import grpc import grpclib +{# These imports will be pruned by the compiler if they are unused. #} +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf.descriptor import Descriptor, EnumDescriptor {# Import the message pool of the generated code. #} {% if output_file.package %} @@ -46,3 +49,8 @@ if TYPE_CHECKING: from grpclib.metadata import Deadline betterproto2.check_compiler_version("{{ version }}") + +{% if output_file.settings.google_protobuf_descriptors %} +{# Add descriptors to Google protobuf's default pool to be more drop-in compatible with other libraries. #} +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile({{ output_file.descriptor }}) +{% endif %} diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 index f51a3b39..2d5fa3d3 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 @@ -6,6 +6,13 @@ class {{ enum.py_name | add_to_all }}(betterproto2.Enum): """ {% endif %} + {% if output_file.settings.google_protobuf_descriptors %} + {# Add descriptor class property to be more drop-in compatible with other libraries. #} + @betterproto2.classproperty + def DESCRIPTOR(self) -> EnumDescriptor: + return DESCRIPTOR.enum_types_by_name['{{ enum.proto_name }}'] + {% endif %} + {% for entry in enum.entries %} {{ entry.name }} = {{ entry.value }} {% if entry.comment %} @@ -45,6 +52,13 @@ class {{ message.py_name | add_to_all }}(betterproto2.Message): """ {% endif %} + {% if output_file.settings.google_protobuf_descriptors %} + {# Add descriptor class property to be more drop-in compatible with other libraries. #} + @betterproto2.classproperty + def DESCRIPTOR(self) -> Descriptor: + return DESCRIPTOR.message_types_by_name['{{ message.proto_name }}'] + {% endif %} + {% for field in message.fields %} {{ field.get_field_string() }} {% if field.comment %} diff --git a/betterproto2_compiler/tests/generate.py b/betterproto2_compiler/tests/generate.py index 8e422d6a..0de41ad4 100644 --- a/betterproto2_compiler/tests/generate.py +++ b/betterproto2_compiler/tests/generate.py @@ -10,6 +10,7 @@ inputs_path, output_path_betterproto, output_path_betterproto_pydantic, + output_path_betterproto_descriptor, output_path_reference, protoc, ) @@ -57,23 +58,28 @@ async def generate_test_case_output(test_case_input_path: Path, test_case_name: test_case_output_path_reference = output_path_reference.joinpath(test_case_name) test_case_output_path_betterproto = output_path_betterproto test_case_output_path_betterproto_pyd = output_path_betterproto_pydantic + test_case_output_path_betterproto_desc = output_path_betterproto_descriptor os.makedirs(test_case_output_path_reference, exist_ok=True) os.makedirs(test_case_output_path_betterproto, exist_ok=True) os.makedirs(test_case_output_path_betterproto_pyd, exist_ok=True) + os.makedirs(test_case_output_path_betterproto_desc, exist_ok=True) clear_directory(test_case_output_path_reference) clear_directory(test_case_output_path_betterproto) clear_directory(test_case_output_path_betterproto_pyd) + clear_directory(test_case_output_path_betterproto_desc) ( (ref_out, ref_err, ref_code), (plg_out, plg_err, plg_code), (plg_out_pyd, plg_err_pyd, plg_code_pyd), + (plg_out_desc, plg_err_desc, plg_code_desc), ) = await asyncio.gather( protoc(test_case_input_path, test_case_output_path_reference, True), protoc(test_case_input_path, test_case_output_path_betterproto, False), protoc(test_case_input_path, test_case_output_path_betterproto_pyd, False, True), + protoc(test_case_input_path, test_case_output_path_betterproto_desc, False, False, True), ) if ref_code == 0: @@ -127,7 +133,24 @@ async def generate_test_case_output(test_case_input_path: Path, test_case_name: sys.stderr.buffer.write(plg_err_pyd) sys.stderr.buffer.flush() - return max(ref_code, plg_code, plg_code_pyd) + if plg_code_desc == 0: + print(f"\033[31;1;4mGenerated plugin (google protobuf descriptor) output for {test_case_name!r}\033[0m") + else: + print(f"\033[31;1;4mFailed to generate plugin (google protobuf descriptor) output for {test_case_name!r}\033[0m") + print(plg_err_desc.decode()) + + if verbose: + if plg_out_desc: + print("Plugin stdout:") + sys.stdout.buffer.write(plg_out_desc) + sys.stdout.buffer.flush() + + if plg_err_desc: + print("Plugin stderr:") + sys.stderr.buffer.write(plg_err_desc) + sys.stderr.buffer.flush() + + return max(ref_code, plg_code, plg_code_pyd, plg_code_desc) def main(): diff --git a/betterproto2_compiler/tests/inputs/grpc_reflection_v1/reflection.proto b/betterproto2_compiler/tests/inputs/grpc_reflection_v1/reflection.proto new file mode 100644 index 00000000..f9f349fe --- /dev/null +++ b/betterproto2_compiler/tests/inputs/grpc_reflection_v1/reflection.proto @@ -0,0 +1,146 @@ +// Copyright 2016 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Service exported by server reflection. A more complete description of how +// server reflection works can be found at +// https://github.com/grpc/grpc/blob/master/doc/server-reflection.md +// +// The canonical version of this proto can be found at +// https://github.com/grpc/grpc-proto/blob/master/grpc/reflection/v1/reflection.proto + +syntax = "proto3"; + +package grpc.reflection.v1; + +option go_package = "google.golang.org/grpc/reflection/grpc_reflection_v1"; +option java_multiple_files = true; +option java_package = "io.grpc.reflection.v1"; +option java_outer_classname = "ServerReflectionProto"; + +service ServerReflection { + // The reflection service is structured as a bidirectional stream, ensuring + // all related requests go to a single server. + rpc ServerReflectionInfo(stream ServerReflectionRequest) + returns (stream ServerReflectionResponse); +} + +// The message sent by the client when calling ServerReflectionInfo method. +message ServerReflectionRequest { + string host = 1; + // To use reflection service, the client should set one of the following + // fields in message_request. The server distinguishes requests by their + // defined field and then handles them using corresponding methods. + oneof message_request { + // Find a proto file by the file name. + string file_by_filename = 3; + + // Find the proto file that declares the given fully-qualified symbol name. + // This field should be a fully-qualified symbol name + // (e.g. .[.] or .). + string file_containing_symbol = 4; + + // Find the proto file which defines an extension extending the given + // message type with the given field number. + ExtensionRequest file_containing_extension = 5; + + // Finds the tag numbers used by all known extensions of the given message + // type, and appends them to ExtensionNumberResponse in an undefined order. + // Its corresponding method is best-effort: it's not guaranteed that the + // reflection service will implement this method, and it's not guaranteed + // that this method will provide all extensions. Returns + // StatusCode::UNIMPLEMENTED if it's not implemented. + // This field should be a fully-qualified type name. The format is + // . + string all_extension_numbers_of_type = 6; + + // List the full names of registered services. The content will not be + // checked. + string list_services = 7; + } +} + +// The type name and extension number sent by the client when requesting +// file_containing_extension. +message ExtensionRequest { + // Fully-qualified type name. The format should be . + string containing_type = 1; + int32 extension_number = 2; +} + +// The message sent by the server to answer ServerReflectionInfo method. +message ServerReflectionResponse { + string valid_host = 1; + ServerReflectionRequest original_request = 2; + // The server sets one of the following fields according to the message_request + // in the request. + oneof message_response { + // This message is used to answer file_by_filename, file_containing_symbol, + // file_containing_extension requests with transitive dependencies. + // As the repeated label is not allowed in oneof fields, we use a + // FileDescriptorResponse message to encapsulate the repeated fields. + // The reflection service is allowed to avoid sending FileDescriptorProtos + // that were previously sent in response to earlier requests in the stream. + FileDescriptorResponse file_descriptor_response = 4; + + // This message is used to answer all_extension_numbers_of_type requests. + ExtensionNumberResponse all_extension_numbers_response = 5; + + // This message is used to answer list_services requests. + ListServiceResponse list_services_response = 6; + + // This message is used when an error occurs. + ErrorResponse error_response = 7; + } +} + +// Serialized FileDescriptorProto messages sent by the server answering +// a file_by_filename, file_containing_symbol, or file_containing_extension +// request. +message FileDescriptorResponse { + // Serialized FileDescriptorProto messages. We avoid taking a dependency on + // descriptor.proto, which uses proto2 only features, by making them opaque + // bytes instead. + repeated bytes file_descriptor_proto = 1; +} + +// A list of extension numbers sent by the server answering +// all_extension_numbers_of_type request. +message ExtensionNumberResponse { + // Full name of the base type, including the package name. The format + // is . + string base_type_name = 1; + repeated int32 extension_number = 2; +} + +// A list of ServiceResponse sent by the server answering list_services request. +message ListServiceResponse { + // The information of each service may be expanded in the future, so we use + // ServiceResponse message to encapsulate it. + repeated ServiceResponse service = 1; +} + +// The information of a single service used by ListServiceResponse to answer +// list_services request. +message ServiceResponse { + // Full name of a registered service, including its package name. The format + // is . + string name = 1; +} + +// The error code and error message sent by the server when an error occurs. +message ErrorResponse { + // This field uses the error codes defined in grpc::StatusCode. + int32 error_code = 1; + string error_message = 2; +} diff --git a/betterproto2_compiler/tests/util.py b/betterproto2_compiler/tests/util.py index 0e8366ff..285b8085 100644 --- a/betterproto2_compiler/tests/util.py +++ b/betterproto2_compiler/tests/util.py @@ -10,6 +10,7 @@ output_path_reference = root_path.joinpath("output_reference") output_path_betterproto = root_path.joinpath("output_betterproto") output_path_betterproto_pydantic = root_path.joinpath("output_betterproto_pydantic") +output_path_betterproto_descriptor = root_path.joinpath("output_betterproto_descriptor") def get_directories(path): @@ -17,18 +18,18 @@ def get_directories(path): yield from directories -async def protoc(path: str | Path, output_dir: str | Path, reference: bool = False, pydantic_dataclasses: bool = False): - path: Path = Path(path).resolve() - output_dir: Path = Path(output_dir).resolve() +async def protoc(path: str | Path, output_dir: str | Path, reference: bool = False, pydantic_dataclasses: bool = False, google_protobuf_descriptors: bool = False): + resolved_path: Path = Path(path).resolve() + resolved_output_dir: Path = Path(output_dir).resolve() python_out_option: str = "python_out" if reference else "python_betterproto2_out" command = [ sys.executable, "-m", "grpc.tools.protoc", - f"--proto_path={path.as_posix()}", - f"--{python_out_option}={output_dir.as_posix()}", - *[p.as_posix() for p in path.glob("*.proto")], + f"--proto_path={resolved_path.as_posix()}", + f"--{python_out_option}={resolved_output_dir.as_posix()}", + *[p.as_posix() for p in resolved_path.glob("*.proto")], ] if not reference: @@ -38,10 +39,13 @@ async def protoc(path: str | Path, output_dir: str | Path, reference: bool = Fal if pydantic_dataclasses: command.insert(3, "--python_betterproto2_opt=pydantic_dataclasses") + if google_protobuf_descriptors: + command.insert(3, "--python_betterproto2_opt=google_protobuf_descriptors") + proc = await asyncio.create_subprocess_exec( *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() - return stdout, stderr, proc.returncode + return stdout, stderr, proc.returncode or 0 From 44e023e53f7a40dde50611535a73e3404e9ae2c9 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Sat, 14 Jun 2025 19:07:50 -0400 Subject: [PATCH 02/17] fix for dependent pkgs --- .../grpc/test_message_enum_descriptors.py | 14 +-- .../compile/importing.py | 94 ++++++++++++------- .../betterproto2_compiler/plugin/models.py | 50 +++++++++- .../templates/header.py.j2 | 9 +- 4 files changed, 125 insertions(+), 42 deletions(-) diff --git a/betterproto2/tests/grpc/test_message_enum_descriptors.py b/betterproto2/tests/grpc/test_message_enum_descriptors.py index 100c71b7..95428721 100644 --- a/betterproto2/tests/grpc/test_message_enum_descriptors.py +++ b/betterproto2/tests/grpc/test_message_enum_descriptors.py @@ -1,17 +1,17 @@ import pytest -from tests.output_betterproto.service import ThingType, DoThingRequest -from tests.output_betterproto_descriptor.service import ThingType as ThingTypeWithDesc, DoThingRequest as DoThingRequestWithDesc +from tests.output_betterproto.import_cousin_package_same_name.test.subpackage import Test +from tests.output_betterproto_descriptor.import_cousin_package_same_name.test.subpackage import Test as TestWithDesc +# importing the cousin should cause no descriptor pool errors since the subpackage imports it once already +from tests.output_betterproto_descriptor.import_cousin_package_same_name.cousin.subpackage import CousinMessage def test_message_enum_descriptors(): # Normally descriptors are not available as they require protobuf support # to inteoperate with other libraries. with pytest.raises(AttributeError): - ThingType.DESCRIPTOR.full_name - with pytest.raises(AttributeError): - DoThingRequest.DESCRIPTOR.full_name + Test.DESCRIPTOR.full_name # But the python_betterproto2_opt=google_protobuf_descriptors option # will add them in as long as protobuf is depended on. - assert ThingTypeWithDesc.DESCRIPTOR.full_name == "service.ThingType" - assert DoThingRequestWithDesc.DESCRIPTOR.full_name == "service.DoThingRequest" + assert TestWithDesc.DESCRIPTOR.full_name == "import_cousin_package_same_name.test.subpackage.Test" + assert CousinMessage.DESCRIPTOR.full_name == "import_cousin_package_same_name.cousin.subpackage.CousinMessage" diff --git a/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py b/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py index 9e4901db..cdbfe05a 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py +++ b/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import sys from typing import ( TYPE_CHECKING, ) @@ -55,6 +56,33 @@ def parse_source_type_name(field_type_name: str, request: PluginRequestCompiler) raise ValueError(f"can't find type name: {field_type_name}") +def get_symbol_reference( + *, + package: str, + imports: set, + source_package: str, + symbol: str, + request: PluginRequestCompiler, + import_suffx: str = "" +) -> tuple[str, str | None]: + """ + Return a Python symbol within a proto package. Adds the import if + necessary and returns it as well for usage. Unwraps well known type if required. + """ + current_package: list[str] = package.split(".") if package else [] + py_package: list[str] = source_package.split(".") if source_package else [] + + if py_package == current_package: + return (reference_sibling(symbol), None) + + if py_package[: len(current_package)] == current_package: + return reference_descendent(current_package, imports, py_package, symbol, import_suffx) + + if current_package[: len(py_package)] == py_package: + return reference_ancestor(current_package, imports, py_package, symbol, import_suffx) + + return reference_cousin(current_package, imports, py_package, symbol, import_suffx) + def get_type_reference( *, package: str, @@ -73,30 +101,26 @@ def get_type_reference( if wrap and (source_package, source_type) in WRAPPED_TYPES: return WRAPPED_TYPES[(source_package, source_type)] - current_package: list[str] = package.split(".") if package else [] - py_package: list[str] = source_package.split(".") if source_package else [] py_type: str = pythonize_class_name(source_type) - - if py_package == current_package: - return reference_sibling(py_type) - - if py_package[: len(current_package)] == current_package: - return reference_descendent(current_package, imports, py_package, py_type) - - if current_package[: len(py_package)] == py_package: - return reference_ancestor(current_package, imports, py_package, py_type) - - return reference_cousin(current_package, imports, py_package, py_type) + (ref, _) = get_symbol_reference( + package=package, + imports=imports, + source_package=source_package, + symbol=py_type, + request=request, + ) + return ref -def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -> str: +def reference_absolute(imports: set[str], py_package: list[str], py_type: str) -> tuple[str, str]: """ Returns a reference to a python type located in the root, i.e. sys.path. """ string_import = ".".join(py_package) string_alias = "__".join([safe_snake_case(name) for name in py_package]) - imports.add(f"import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + import_to_add = f"import {string_import} as {string_alias}" + imports.add(import_to_add) + return (f"{string_alias}.{py_type}", import_to_add) def reference_sibling(py_type: str) -> str: @@ -106,7 +130,7 @@ def reference_sibling(py_type: str) -> str: return f"{py_type}" -def reference_descendent(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str: +def reference_descendent(current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "") -> tuple[str, str]: """ Returns a reference to a python type in a package that is a descendent of the current package, and adds the required import that is aliased to avoid name @@ -116,15 +140,17 @@ def reference_descendent(current_package: list[str], imports: set[str], py_packa string_from = ".".join(importing_descendent[:-1]) string_import = importing_descendent[-1] if string_from: - string_alias = "_".join(importing_descendent) - imports.add(f"from .{string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + string_alias = f'{"_".join(importing_descendent)}{import_suffix}' + import_to_add = f"from .{string_from} import {string_import} as {string_alias}" + imports.add(import_to_add) + return (f"{string_alias}.{py_type}", import_to_add) else: - imports.add(f"from . import {string_import}") - return f"{string_import}.{py_type}" + import_to_add = f"from . import {string_import}" + imports.add(import_to_add) + return (f"{string_import}.{py_type}", import_to_add) -def reference_ancestor(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str: +def reference_ancestor(current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "") -> tuple[str, str]: """ Returns a reference to a python type in a package which is an ancestor to the current package, and adds the required import that is aliased (if possible) to avoid @@ -135,17 +161,19 @@ def reference_ancestor(current_package: list[str], imports: set[str], py_package distance_up = len(current_package) - len(py_package) if py_package: string_import = py_package[-1] - string_alias = f"_{'_' * distance_up}{string_import}__" + string_alias = f"_{'_' * distance_up}{string_import}__{import_suffix}" string_from = f"..{'.' * distance_up}" - imports.add(f"from {string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + import_to_add = f"from {string_from} import {string_import} as {string_alias}" + imports.add(import_to_add) + return (f"{string_alias}.{py_type}", import_to_add) else: - string_alias = f"{'_' * distance_up}{py_type}__" - imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}") - return string_alias + string_alias = f"{'_' * distance_up}{py_type}__{import_suffix}" + import_to_add = f"from .{'.' * distance_up} import {py_type} as {string_alias}" + imports.add(import_to_add) + return (string_alias, import_to_add) -def reference_cousin(current_package: list[str], imports: set[str], py_package: list[str], py_type: str) -> str: +def reference_cousin(current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "") -> tuple[str, str]: """ Returns a reference to a python type in a package that is not descendent, ancestor or sibling, and adds the required import that is aliased to avoid name conflicts. @@ -160,6 +188,8 @@ def reference_cousin(current_package: list[str], imports: set[str], py_package: f"{'_' * distance_up}" + "__".join([safe_snake_case(name) for name in py_package[len(shared_ancestry) :]]) + "__" + + import_suffix ) - imports.add(f"from {string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + import_to_add = f"from {string_from} import {string_import} as {string_alias}" + imports.add(import_to_add) + return (f"{string_alias}.{py_type}", import_to_add) diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index 923910d4..ee4690ae 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -34,7 +34,7 @@ from betterproto2 import unwrap -from betterproto2_compiler.compile.importing import get_type_reference, parse_source_type_name +from betterproto2_compiler.compile.importing import get_symbol_reference, get_type_reference, parse_source_type_name from betterproto2_compiler.compile.naming import ( pythonize_class_name, pythonize_enum_member_name, @@ -216,8 +216,42 @@ def input_filenames(self) -> list[str]: """ return sorted([f.name for f in self.input_files]) + @property + def dependency_imports(self): + """Proto dependencies as Python imports. + + Returns + ------- + list[str] + Imports for each proto dependency of this package resolved to the output + names, not the input file paths. + """ + def dep_to_pkg_import(dep: str) -> str: + for (output_name, pkg) in self.parent_request.output_packages.items(): + if dep in pkg.input_filenames: + ref, ref_import = get_symbol_reference( + package=self.package, + imports=self.imports_end, + source_package=output_name, + request=self.parent_request, + symbol="_COMPILER_VERSION", + import_suffx="_dep" + ) + + # import and check compiler version for safety and to avoid this import being removed. + return f'{ref_import}\nbetterproto2.check_compiler_version({ref})' + raise ValueError(f"cannot find which output package {dep} belongs to") + return [dep_to_pkg_import(dep) for dep in self.package_proto_obj.dependency] + @property def descriptor(self): + """Google protobuf library descriptor. + + Returns + ------- + str + A binary string of the package's proto descriptor. + """ return self.package_proto_obj.SerializeToString() @@ -272,6 +306,13 @@ def custom_methods(self) -> list[str]: @property def descriptor(self): + """Google protobuf library descriptor. + + Returns + ------- + str + A binary string of the message's proto descriptor. + """ return self.proto_obj.SerializeToString() @@ -605,6 +646,13 @@ def deprecated(self) -> bool: @property def descriptor(self): + """Google protobuf library descriptor. + + Returns + ------- + str + A binary string of the enum's proto descriptor. + """ return self.proto_obj.SerializeToString() diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 index 2655442e..40e18185 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 @@ -32,7 +32,6 @@ import betterproto2 from betterproto2.grpc.grpclib_server import ServiceBase import grpc import grpclib -{# These imports will be pruned by the compiler if they are unused. #} from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf.descriptor import Descriptor, EnumDescriptor @@ -48,9 +47,15 @@ if TYPE_CHECKING: from betterproto2.grpc.grpclib_client import MetadataLike from grpclib.metadata import Deadline -betterproto2.check_compiler_version("{{ version }}") +_COMPILER_VERSION="{{ version }}" +betterproto2.check_compiler_version(_COMPILER_VERSION) {% if output_file.settings.google_protobuf_descriptors %} + +{% for import in output_file.dependency_imports -%} +{{ import }} +{% endfor %} + {# Add descriptors to Google protobuf's default pool to be more drop-in compatible with other libraries. #} DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile({{ output_file.descriptor }}) {% endif %} From 5c7548e64c63b5971c2c23eb7d541f62db67b2d4 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Sat, 14 Jun 2025 20:21:14 -0400 Subject: [PATCH 03/17] fix multiple descriptors --- .../betterproto2_compiler/plugin/models.py | 36 +++++++++++++++---- .../templates/header.py.j2 | 7 ++-- .../templates/template.py.j2 | 4 +-- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index ee4690ae..d1e177e9 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -241,19 +241,21 @@ def dep_to_pkg_import(dep: str) -> str: # import and check compiler version for safety and to avoid this import being removed. return f'{ref_import}\nbetterproto2.check_compiler_version({ref})' raise ValueError(f"cannot find which output package {dep} belongs to") - return [dep_to_pkg_import(dep) for dep in self.package_proto_obj.dependency] + return {dep_to_pkg_import(dep) for dep in self.package_proto_obj.dependency} + + def get_descriptor_name(self, source_file: FileDescriptorProto): + return f'{source_file.name.replace('/', '_').replace('.', '_').upper()}_DESCRIPTOR' @property - def descriptor(self): - """Google protobuf library descriptor. + def descriptors(self): + """Google protobuf library descriptors. Returns ------- str - A binary string of the package's proto descriptor. + A list of pool registrations for proto descriptors. """ - return self.package_proto_obj.SerializeToString() - + return '\n'.join([f'{self.get_descriptor_name(input_file)} = _descriptor_pool.Default().AddSerializedFile({input_file.SerializeToString()})' for input_file in self.input_files]) @dataclass(kw_only=True) class MessageCompiler(ProtoContentBase): @@ -304,6 +306,17 @@ def custom_methods(self) -> list[str]: return methods_source + @property + def descriptor_name(self): + """Google protobuf library descriptor name. + + Returns + ------- + str + The Python name of the descriptor to reference. + """ + return self.output_file.get_descriptor_name(self.source_file) + @property def descriptor(self): """Google protobuf library descriptor. @@ -644,6 +657,17 @@ def py_name(self) -> str: def deprecated(self) -> bool: return bool(self.proto_obj.options and self.proto_obj.options.deprecated) + @property + def descriptor_name(self): + """Google protobuf library descriptor name. + + Returns + ------- + str + The Python name of the descriptor to reference. + """ + return self.output_file.get_descriptor_name(self.source_file) + @property def descriptor(self): """Google protobuf library descriptor. diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 index 40e18185..67ec43f7 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 @@ -52,10 +52,11 @@ betterproto2.check_compiler_version(_COMPILER_VERSION) {% if output_file.settings.google_protobuf_descriptors %} -{% for import in output_file.dependency_imports -%} -{{ import }} +{% for import in output_file.dependency_imports %} +{{ import.rstrip() }} +{{ "\n" }} {% endfor %} {# Add descriptors to Google protobuf's default pool to be more drop-in compatible with other libraries. #} -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile({{ output_file.descriptor }}) +{{ output_file.descriptors }} {% endif %} diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 index 2d5fa3d3..aa6d71d4 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 @@ -10,7 +10,7 @@ class {{ enum.py_name | add_to_all }}(betterproto2.Enum): {# Add descriptor class property to be more drop-in compatible with other libraries. #} @betterproto2.classproperty def DESCRIPTOR(self) -> EnumDescriptor: - return DESCRIPTOR.enum_types_by_name['{{ enum.proto_name }}'] + return {{ enum.descriptor_name }}.enum_types_by_name['{{ enum.proto_name }}'] {% endif %} {% for entry in enum.entries %} @@ -56,7 +56,7 @@ class {{ message.py_name | add_to_all }}(betterproto2.Message): {# Add descriptor class property to be more drop-in compatible with other libraries. #} @betterproto2.classproperty def DESCRIPTOR(self) -> Descriptor: - return DESCRIPTOR.message_types_by_name['{{ message.proto_name }}'] + return {{ message.descriptor_name }}.message_types_by_name['{{ message.proto_name }}'] {% endif %} {% for field in message.fields %} From 2227f3633ce24881296c9853ac4828aa81cc2387 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Sat, 14 Jun 2025 21:55:06 -0400 Subject: [PATCH 04/17] driveby --- .../src/betterproto2_compiler/templates/template.py.j2 | 3 +++ 1 file changed, 3 insertions(+) diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 index aa6d71d4..6a6addef 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 @@ -141,6 +141,9 @@ class {{ (service.py_name + "Base") | add_to_all }}(ServiceBase): {% endif %} raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) + {% if method.server_streaming %} + yield {{ method.py_output_message_type }}() + {% endif %} {% endfor %} From af37fdcda24d16b53995c3186a2077c03c345aac Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Sat, 14 Jun 2025 21:56:25 -0400 Subject: [PATCH 05/17] comment --- .../src/betterproto2_compiler/templates/template.py.j2 | 1 + 1 file changed, 1 insertion(+) diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 index 6a6addef..192d821f 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 @@ -142,6 +142,7 @@ class {{ (service.py_name + "Base") | add_to_all }}(ServiceBase): raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) {% if method.server_streaming %} + {# yielding here changes the return type from a coroutine to an async_generator #} yield {{ method.py_output_message_type }}() {% endif %} From 91e9be4fda830cc3ca57addd3414af163a06dd7d Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Sat, 14 Jun 2025 22:38:29 -0400 Subject: [PATCH 06/17] order imports --- .../src/betterproto2_compiler/plugin/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index d1e177e9..045fc942 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -226,7 +226,7 @@ def dependency_imports(self): Imports for each proto dependency of this package resolved to the output names, not the input file paths. """ - def dep_to_pkg_import(dep: str) -> str: + def dep_to_pkg_import(dep: str) -> tuple[str, str]: for (output_name, pkg) in self.parent_request.output_packages.items(): if dep in pkg.input_filenames: ref, ref_import = get_symbol_reference( @@ -239,9 +239,9 @@ def dep_to_pkg_import(dep: str) -> str: ) # import and check compiler version for safety and to avoid this import being removed. - return f'{ref_import}\nbetterproto2.check_compiler_version({ref})' + return (ref, f'{ref_import}\nbetterproto2.check_compiler_version({ref})') raise ValueError(f"cannot find which output package {dep} belongs to") - return {dep_to_pkg_import(dep) for dep in self.package_proto_obj.dependency} + return dict([dep_to_pkg_import(dep) for dep in self.package_proto_obj.dependency]).values() def get_descriptor_name(self, source_file: FileDescriptorProto): return f'{source_file.name.replace('/', '_').replace('.', '_').upper()}_DESCRIPTOR' From e53ea3646f19baab4891b03a68f054b08c8c69e9 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Sun, 15 Jun 2025 09:55:47 -0400 Subject: [PATCH 07/17] fix nested descriptors and stop mutating file input --- betterproto2/tests/grpc/test_grpclib_reflection.py | 2 +- betterproto2/tests/test_deprecated.py | 7 +++++++ .../src/betterproto2_compiler/plugin/models.py | 14 +++++++++++--- .../src/betterproto2_compiler/plugin/parser.py | 8 ++++---- .../betterproto2_compiler/templates/template.py.j2 | 6 +++--- .../tests/inputs/deprecated/deprecated.proto | 3 +++ 6 files changed, 29 insertions(+), 11 deletions(-) diff --git a/betterproto2/tests/grpc/test_grpclib_reflection.py b/betterproto2/tests/grpc/test_grpclib_reflection.py index 4ed70bc1..d7619c7e 100644 --- a/betterproto2/tests/grpc/test_grpclib_reflection.py +++ b/betterproto2/tests/grpc/test_grpclib_reflection.py @@ -62,7 +62,7 @@ async def test_grpclib_reflection(): import tests.output_betterproto_descriptor.example_service as example_service_with_desc requests.put(ServerReflectionRequest(file_containing_symbol="example_service.Test")) response = await anext(responses) - expected = descriptor_pb2.FileDescriptorProto.FromString(example_service_with_desc.DESCRIPTOR.serialized_pb) + expected = descriptor_pb2.FileDescriptorProto.FromString(example_service_with_desc.EXAMPLE_SERVICE_PROTO_DESCRIPTOR.serialized_pb) assert response.error_response is None assert response.file_descriptor_response is not None assert len(response.file_descriptor_response.file_descriptor_proto) == 1 diff --git a/betterproto2/tests/test_deprecated.py b/betterproto2/tests/test_deprecated.py index ea16d370..d7819c6b 100644 --- a/betterproto2/tests/test_deprecated.py +++ b/betterproto2/tests/test_deprecated.py @@ -25,6 +25,13 @@ def test_deprecated_message(): assert len(record) == 1 assert str(record[0].message) == f"{Message.__name__} is deprecated" +def test_deprecated_nested_message(): + with pytest.warns(DeprecationWarning) as record: + Message(value="hello") + + assert len(record) == 1 + assert str(record[0].message) == f"{Message.__name__} is deprecated" + def test_message_with_deprecated_field(message): with pytest.warns(DeprecationWarning) as record: diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index 045fc942..237bfa5b 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -271,9 +271,13 @@ class MessageCompiler(ProtoContentBase): def proto_name(self) -> str: return self.proto_obj.name + @property + def prefixed_proto_name(self) -> str: + return self.proto_obj.prefixed_name + @property def py_name(self) -> str: - return pythonize_class_name(self.proto_name) + return pythonize_class_name(self.proto_obj.prefixed_name) @property def deprecated(self) -> bool: @@ -283,7 +287,7 @@ def deprecated(self) -> bool: def deprecated_fields(self) -> Iterator[str]: for f in self.fields: if f.deprecated: - yield f.py_name + yield f.proto_name @property def has_deprecated_fields(self) -> bool: @@ -649,9 +653,13 @@ def __post_init__(self) -> None: def proto_name(self) -> str: return self.proto_obj.name + @property + def prefixed_proto_name(self) -> str: + return self.proto_obj.prefixed_name + @property def py_name(self) -> str: - return pythonize_class_name(self.proto_name) + return pythonize_class_name(self.proto_obj.prefixed_name) @property def deprecated(self) -> bool: diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py index 5ea41c36..085b3c9d 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py @@ -44,10 +44,10 @@ def _traverse( ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]: for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. - # Todo: don't change the name, but include full name in returned tuple should_rename = not isinstance(item, DescriptorProto) or not item.options or not item.options.map_entry - item.name = next_prefix = f"{prefix}.{item.name}" if prefix and should_rename else item.name + # Record prefixed name but *do not* mutate original file. + item.prefixed_name = next_prefix = f"{prefix}.{item.name}" if prefix and should_rename else item.name yield item, [*path, i] if isinstance(item, DescriptorProto): @@ -192,7 +192,7 @@ def read_protobuf_type( proto_obj=item, path=path, ) - output_package.messages[message_data.proto_name] = message_data + output_package.messages[message_data.prefixed_proto_name] = message_data for index, field in enumerate(item.field): if is_map(field, item): @@ -247,7 +247,7 @@ def read_protobuf_type( proto_obj=item, path=path, ) - output_package.enums[enum.proto_name] = enum + output_package.enums[enum.prefixed_proto_name] = enum def read_protobuf_service( diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 index 192d821f..2f705cdf 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 @@ -10,7 +10,7 @@ class {{ enum.py_name | add_to_all }}(betterproto2.Enum): {# Add descriptor class property to be more drop-in compatible with other libraries. #} @betterproto2.classproperty def DESCRIPTOR(self) -> EnumDescriptor: - return {{ enum.descriptor_name }}.enum_types_by_name['{{ enum.proto_name }}'] + return {{ enum.descriptor_name }}.enum_types_by_name['{{ enum.prefixed_proto_name }}'] {% endif %} {% for entry in enum.entries %} @@ -56,7 +56,7 @@ class {{ message.py_name | add_to_all }}(betterproto2.Message): {# Add descriptor class property to be more drop-in compatible with other libraries. #} @betterproto2.classproperty def DESCRIPTOR(self) -> Descriptor: - return {{ message.descriptor_name }}.message_types_by_name['{{ message.proto_name }}'] + return {{ message.descriptor_name }}.message_types_by_name['{{ message.prefixed_proto_name }}'] {% endif %} {% for field in message.fields %} @@ -95,7 +95,7 @@ class {{ message.py_name | add_to_all }}(betterproto2.Message): {{ method_source }} {% endfor %} -default_message_pool.register_message("{{ output_file.package }}", "{{ message.proto_name }}", {{ message.py_name }}) +default_message_pool.register_message("{{ output_file.package }}", "{{ message.prefixed_proto_name }}", {{ message.py_name }}) {% endfor %} diff --git a/betterproto2_compiler/tests/inputs/deprecated/deprecated.proto b/betterproto2_compiler/tests/inputs/deprecated/deprecated.proto index f504d03a..2e64c621 100644 --- a/betterproto2_compiler/tests/inputs/deprecated/deprecated.proto +++ b/betterproto2_compiler/tests/inputs/deprecated/deprecated.proto @@ -6,6 +6,9 @@ package deprecated; message Test { Message message = 1 [deprecated=true]; int32 value = 2; + message Nested { + int32 nested_value = 1 [deprecated=true]; + } } message Message { From ce69a386b1a0302b7b35c1f5d902ee56aeaac7b4 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Mon, 16 Jun 2025 07:28:17 -0400 Subject: [PATCH 08/17] lint --- .../tests/grpc/test_grpclib_reflection.py | 34 +++++++++++++------ .../grpc/test_message_enum_descriptors.py | 4 ++- betterproto2/tests/test_deprecated.py | 1 + .../compile/importing.py | 18 ++++++---- .../betterproto2_compiler/plugin/models.py | 18 +++++++--- betterproto2_compiler/tests/generate.py | 6 ++-- betterproto2_compiler/tests/util.py | 8 ++++- 7 files changed, 64 insertions(+), 25 deletions(-) diff --git a/betterproto2/tests/grpc/test_grpclib_reflection.py b/betterproto2/tests/grpc/test_grpclib_reflection.py index d7619c7e..99cdbbf3 100644 --- a/betterproto2/tests/grpc/test_grpclib_reflection.py +++ b/betterproto2/tests/grpc/test_grpclib_reflection.py @@ -1,21 +1,28 @@ import asyncio from typing import Generic, TypeVar -import grpclib -from grpclib.reflection.service import ServerReflection import pytest -from grpclib.testing import ChannelFor from google.protobuf import descriptor_pb2 - -from tests.grpc.async_channel import AsyncChannel -from tests.output_betterproto.grpc.reflection.v1 import ErrorResponse, ListServiceResponse, ServiceResponse, ServerReflectionRequest, ServerReflectionStub +from grpclib.reflection.service import ServerReflection +from grpclib.testing import ChannelFor from tests.output_betterproto.example_service import TestBase +from tests.output_betterproto.grpc.reflection.v1 import ( + ErrorResponse, + ListServiceResponse, + ServerReflectionRequest, + ServerReflectionStub, + ServiceResponse, +) + class TestService(TestBase): pass + T = TypeVar("T") + + class AsyncIterableQueue(Generic[T]): def __init__(self): self._queue = asyncio.Queue() @@ -36,6 +43,7 @@ async def __anext__(self) -> T: except asyncio.QueueShutDown: raise StopAsyncIteration + @pytest.mark.asyncio async def test_grpclib_reflection(): service = TestService() @@ -48,25 +56,31 @@ async def test_grpclib_reflection(): requests.put(ServerReflectionRequest(list_services="")) response = await anext(responses) assert response.list_services_response == ListServiceResponse( - service=[ServiceResponse(name='example_service.Test')]) + service=[ServiceResponse(name="example_service.Test")] + ) # list methods # should fail before we've added descriptors to the protobuf pool requests.put(ServerReflectionRequest(file_containing_symbol="example_service.Test")) response = await anext(responses) - assert response.error_response == ErrorResponse(error_code=5, error_message='not found') + assert response.error_response == ErrorResponse(error_code=5, error_message="not found") assert response.file_descriptor_response is None # now it should work import tests.output_betterproto_descriptor.example_service as example_service_with_desc + requests.put(ServerReflectionRequest(file_containing_symbol="example_service.Test")) response = await anext(responses) - expected = descriptor_pb2.FileDescriptorProto.FromString(example_service_with_desc.EXAMPLE_SERVICE_PROTO_DESCRIPTOR.serialized_pb) + expected = descriptor_pb2.FileDescriptorProto.FromString( + example_service_with_desc.EXAMPLE_SERVICE_PROTO_DESCRIPTOR.serialized_pb + ) assert response.error_response is None assert response.file_descriptor_response is not None assert len(response.file_descriptor_response.file_descriptor_proto) == 1 - actual = descriptor_pb2.FileDescriptorProto.FromString(response.file_descriptor_response.file_descriptor_proto[0]) + actual = descriptor_pb2.FileDescriptorProto.FromString( + response.file_descriptor_response.file_descriptor_proto[0] + ) assert actual == expected requests.close() diff --git a/betterproto2/tests/grpc/test_message_enum_descriptors.py b/betterproto2/tests/grpc/test_message_enum_descriptors.py index 95428721..a7922820 100644 --- a/betterproto2/tests/grpc/test_message_enum_descriptors.py +++ b/betterproto2/tests/grpc/test_message_enum_descriptors.py @@ -1,9 +1,11 @@ import pytest from tests.output_betterproto.import_cousin_package_same_name.test.subpackage import Test -from tests.output_betterproto_descriptor.import_cousin_package_same_name.test.subpackage import Test as TestWithDesc + # importing the cousin should cause no descriptor pool errors since the subpackage imports it once already from tests.output_betterproto_descriptor.import_cousin_package_same_name.cousin.subpackage import CousinMessage +from tests.output_betterproto_descriptor.import_cousin_package_same_name.test.subpackage import Test as TestWithDesc + def test_message_enum_descriptors(): # Normally descriptors are not available as they require protobuf support diff --git a/betterproto2/tests/test_deprecated.py b/betterproto2/tests/test_deprecated.py index d7819c6b..d9a81c02 100644 --- a/betterproto2/tests/test_deprecated.py +++ b/betterproto2/tests/test_deprecated.py @@ -25,6 +25,7 @@ def test_deprecated_message(): assert len(record) == 1 assert str(record[0].message) == f"{Message.__name__} is deprecated" + def test_deprecated_nested_message(): with pytest.warns(DeprecationWarning) as record: Message(value="hello") diff --git a/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py b/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py index cdbfe05a..4f6873c3 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py +++ b/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import sys from typing import ( TYPE_CHECKING, ) @@ -63,7 +62,7 @@ def get_symbol_reference( source_package: str, symbol: str, request: PluginRequestCompiler, - import_suffx: str = "" + import_suffx: str = "", ) -> tuple[str, str | None]: """ Return a Python symbol within a proto package. Adds the import if @@ -83,6 +82,7 @@ def get_symbol_reference( return reference_cousin(current_package, imports, py_package, symbol, import_suffx) + def get_type_reference( *, package: str, @@ -130,7 +130,9 @@ def reference_sibling(py_type: str) -> str: return f"{py_type}" -def reference_descendent(current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "") -> tuple[str, str]: +def reference_descendent( + current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "" +) -> tuple[str, str]: """ Returns a reference to a python type in a package that is a descendent of the current package, and adds the required import that is aliased to avoid name @@ -140,7 +142,7 @@ def reference_descendent(current_package: list[str], imports: set[str], py_packa string_from = ".".join(importing_descendent[:-1]) string_import = importing_descendent[-1] if string_from: - string_alias = f'{"_".join(importing_descendent)}{import_suffix}' + string_alias = f"{'_'.join(importing_descendent)}{import_suffix}" import_to_add = f"from .{string_from} import {string_import} as {string_alias}" imports.add(import_to_add) return (f"{string_alias}.{py_type}", import_to_add) @@ -150,7 +152,9 @@ def reference_descendent(current_package: list[str], imports: set[str], py_packa return (f"{string_import}.{py_type}", import_to_add) -def reference_ancestor(current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "") -> tuple[str, str]: +def reference_ancestor( + current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "" +) -> tuple[str, str]: """ Returns a reference to a python type in a package which is an ancestor to the current package, and adds the required import that is aliased (if possible) to avoid @@ -173,7 +177,9 @@ def reference_ancestor(current_package: list[str], imports: set[str], py_package return (string_alias, import_to_add) -def reference_cousin(current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "") -> tuple[str, str]: +def reference_cousin( + current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "" +) -> tuple[str, str]: """ Returns a reference to a python type in a package that is not descendent, ancestor or sibling, and adds the required import that is aliased to avoid name conflicts. diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index 237bfa5b..f92cb56a 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -226,8 +226,9 @@ def dependency_imports(self): Imports for each proto dependency of this package resolved to the output names, not the input file paths. """ + def dep_to_pkg_import(dep: str) -> tuple[str, str]: - for (output_name, pkg) in self.parent_request.output_packages.items(): + for output_name, pkg in self.parent_request.output_packages.items(): if dep in pkg.input_filenames: ref, ref_import = get_symbol_reference( package=self.package, @@ -235,16 +236,17 @@ def dep_to_pkg_import(dep: str) -> tuple[str, str]: source_package=output_name, request=self.parent_request, symbol="_COMPILER_VERSION", - import_suffx="_dep" + import_suffx="_dep", ) # import and check compiler version for safety and to avoid this import being removed. - return (ref, f'{ref_import}\nbetterproto2.check_compiler_version({ref})') + return (ref, f"{ref_import}\nbetterproto2.check_compiler_version({ref})") raise ValueError(f"cannot find which output package {dep} belongs to") + return dict([dep_to_pkg_import(dep) for dep in self.package_proto_obj.dependency]).values() def get_descriptor_name(self, source_file: FileDescriptorProto): - return f'{source_file.name.replace('/', '_').replace('.', '_').upper()}_DESCRIPTOR' + return f"{source_file.name.replace('/', '_').replace('.', '_').upper()}_DESCRIPTOR" @property def descriptors(self): @@ -255,7 +257,13 @@ def descriptors(self): str A list of pool registrations for proto descriptors. """ - return '\n'.join([f'{self.get_descriptor_name(input_file)} = _descriptor_pool.Default().AddSerializedFile({input_file.SerializeToString()})' for input_file in self.input_files]) + return "\n".join( + [ + f"{self.get_descriptor_name(f)} = _descriptor_pool.Default().AddSerializedFile({f.SerializeToString()})" + for f in self.input_files + ] + ) + @dataclass(kw_only=True) class MessageCompiler(ProtoContentBase): diff --git a/betterproto2_compiler/tests/generate.py b/betterproto2_compiler/tests/generate.py index 0de41ad4..16b5297b 100644 --- a/betterproto2_compiler/tests/generate.py +++ b/betterproto2_compiler/tests/generate.py @@ -9,8 +9,8 @@ get_directories, inputs_path, output_path_betterproto, - output_path_betterproto_pydantic, output_path_betterproto_descriptor, + output_path_betterproto_pydantic, output_path_reference, protoc, ) @@ -136,7 +136,9 @@ async def generate_test_case_output(test_case_input_path: Path, test_case_name: if plg_code_desc == 0: print(f"\033[31;1;4mGenerated plugin (google protobuf descriptor) output for {test_case_name!r}\033[0m") else: - print(f"\033[31;1;4mFailed to generate plugin (google protobuf descriptor) output for {test_case_name!r}\033[0m") + print( + f"\033[31;1;4mFailed to generate plugin (google protobuf descriptor) output for {test_case_name!r}\033[0m" + ) print(plg_err_desc.decode()) if verbose: diff --git a/betterproto2_compiler/tests/util.py b/betterproto2_compiler/tests/util.py index 285b8085..8e12dda7 100644 --- a/betterproto2_compiler/tests/util.py +++ b/betterproto2_compiler/tests/util.py @@ -18,7 +18,13 @@ def get_directories(path): yield from directories -async def protoc(path: str | Path, output_dir: str | Path, reference: bool = False, pydantic_dataclasses: bool = False, google_protobuf_descriptors: bool = False): +async def protoc( + path: str | Path, + output_dir: str | Path, + reference: bool = False, + pydantic_dataclasses: bool = False, + google_protobuf_descriptors: bool = False, +): resolved_path: Path = Path(path).resolve() resolved_output_dir: Path = Path(output_dir).resolve() python_out_option: str = "python_out" if reference else "python_betterproto2_out" From 8b4ce01f6ff949faf9e1d58f804424a73992e4a5 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Mon, 16 Jun 2025 07:41:27 -0400 Subject: [PATCH 09/17] fix --- betterproto2/tests/grpc/test_grpclib_reflection.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/betterproto2/tests/grpc/test_grpclib_reflection.py b/betterproto2/tests/grpc/test_grpclib_reflection.py index 99cdbbf3..d5334aca 100644 --- a/betterproto2/tests/grpc/test_grpclib_reflection.py +++ b/betterproto2/tests/grpc/test_grpclib_reflection.py @@ -24,6 +24,8 @@ class TestService(TestBase): class AsyncIterableQueue(Generic[T]): + CLOSED_SENTINEL = object() + def __init__(self): self._queue = asyncio.Queue() self._done = asyncio.Event() @@ -32,16 +34,16 @@ def put(self, item: T): self._queue.put_nowait(item) def close(self): - self._queue.shutdown() + self._queue.put_nowait(self.CLOSED_SENTINEL) def __aiter__(self): return self async def __anext__(self) -> T: - try: - return await self._queue.get() - except asyncio.QueueShutDown: + val = await self._queue.get() + if val is self.CLOSED_SENTINEL: raise StopAsyncIteration + return val @pytest.mark.asyncio From a85929c273fc58997933573fe670a7e3b1011ffe Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Mon, 16 Jun 2025 12:29:59 -0400 Subject: [PATCH 10/17] fix default groups --- betterproto2/pyproject.toml | 2 +- betterproto2_compiler/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/betterproto2/pyproject.toml b/betterproto2/pyproject.toml index 65798b8e..525db10a 100644 --- a/betterproto2/pyproject.toml +++ b/betterproto2/pyproject.toml @@ -47,7 +47,7 @@ test = [ [tool.uv] package = true -default-groups = "all" +default-groups = ["all"] [build-system] requires = ["hatchling"] diff --git a/betterproto2_compiler/pyproject.toml b/betterproto2_compiler/pyproject.toml index ab3804a2..c313d313 100644 --- a/betterproto2_compiler/pyproject.toml +++ b/betterproto2_compiler/pyproject.toml @@ -44,7 +44,7 @@ test = [ [tool.uv] package = true -default-groups = "all" +default-groups = ["all"] # [tool.hatch.build.targets.sdist] # include = ["src/betterproto2_compiler"] From 485954e6878c80028502faa6168bfd8bf8e95656 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Wed, 18 Jun 2025 12:11:09 -0400 Subject: [PATCH 11/17] Revert "fix default groups" This reverts commit a85929c273fc58997933573fe670a7e3b1011ffe. --- betterproto2/pyproject.toml | 2 +- betterproto2_compiler/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/betterproto2/pyproject.toml b/betterproto2/pyproject.toml index 525db10a..65798b8e 100644 --- a/betterproto2/pyproject.toml +++ b/betterproto2/pyproject.toml @@ -47,7 +47,7 @@ test = [ [tool.uv] package = true -default-groups = ["all"] +default-groups = "all" [build-system] requires = ["hatchling"] diff --git a/betterproto2_compiler/pyproject.toml b/betterproto2_compiler/pyproject.toml index c313d313..ab3804a2 100644 --- a/betterproto2_compiler/pyproject.toml +++ b/betterproto2_compiler/pyproject.toml @@ -44,7 +44,7 @@ test = [ [tool.uv] package = true -default-groups = ["all"] +default-groups = "all" # [tool.hatch.build.targets.sdist] # include = ["src/betterproto2_compiler"] From 5dd0efc45e4124663d7443cd853266f013339e2d Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Wed, 18 Jun 2025 18:44:37 -0400 Subject: [PATCH 12/17] better descriptor --- betterproto2/tests/grpc/test_grpclib_reflection.py | 8 ++++++++ .../src/betterproto2_compiler/plugin/models.py | 3 ++- .../src/betterproto2_compiler/plugin/parser.py | 9 +++++++++ .../src/betterproto2_compiler/templates/header.py.j2 | 7 ++++++- 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/betterproto2/tests/grpc/test_grpclib_reflection.py b/betterproto2/tests/grpc/test_grpclib_reflection.py index d5334aca..3680cf7a 100644 --- a/betterproto2/tests/grpc/test_grpclib_reflection.py +++ b/betterproto2/tests/grpc/test_grpclib_reflection.py @@ -4,6 +4,8 @@ import pytest from google.protobuf import descriptor_pb2 from grpclib.reflection.service import ServerReflection +from grpclib.reflection.v1.reflection_grpc import ServerReflectionBase as ServerReflectionBaseV1 +from grpclib.reflection.v1alpha.reflection_grpc import ServerReflectionBase as ServerReflectionBaseV1Alpha from grpclib.testing import ChannelFor from tests.output_betterproto.example_service import TestBase @@ -14,6 +16,7 @@ ServerReflectionStub, ServiceResponse, ) +from tests.output_betterproto_descriptor.google_proto_descriptor_pool import default_google_proto_descriptor_pool class TestService(TestBase): @@ -50,6 +53,11 @@ async def __anext__(self) -> T: async def test_grpclib_reflection(): service = TestService() services = ServerReflection.extend([service]) + for service in services: + # This won't be needed once https://github.com/vmagamedov/grpclib/pull/204 is in. + if isinstance(service, ServerReflectionBaseV1Alpha | ServerReflectionBaseV1): + service._pool = default_google_proto_descriptor_pool + async with ChannelFor(services) as channel: requests = AsyncIterableQueue[ServerReflectionRequest]() responses = ServerReflectionStub(channel).server_reflection_info(requests) diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index f92cb56a..99106bf4 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -259,7 +259,8 @@ def descriptors(self): """ return "\n".join( [ - f"{self.get_descriptor_name(f)} = _descriptor_pool.Default().AddSerializedFile({f.SerializeToString()})" + f"{self.get_descriptor_name(f)} = " + + f"default_google_proto_descriptor_pool.AddSerializedFile({f.SerializeToString()})" for f in self.input_files ] ) diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py index 085b3c9d..de71c18c 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py @@ -169,6 +169,15 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: ) ) + if settings.google_protobuf_descriptors: + response.file.append( + CodeGeneratorResponseFile( + name="google_proto_descriptor_pool.py", + content="from google.protobuf import descriptor_pool\n\n" + + "default_google_proto_descriptor_pool = descriptor_pool.DescriptorPool()\n", + ) + ) + for output_package_name in sorted(output_paths.union(init_files)): print(f"Writing {output_package_name}", file=sys.stderr) diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 index 67ec43f7..892d30d5 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 @@ -32,14 +32,19 @@ import betterproto2 from betterproto2.grpc.grpclib_server import ServiceBase import grpc import grpclib -from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf.descriptor import Descriptor, EnumDescriptor {# Import the message pool of the generated code. #} {% if output_file.package %} from {{ "." * output_file.package.count(".") }}..message_pool import default_message_pool +{% if output_file.settings.google_protobuf_descriptors %} +from {{ "." * output_file.package.count(".") }}..google_proto_descriptor_pool import default_google_proto_descriptor_pool +{% endif %} {% else %} from .message_pool import default_message_pool +{% if output_file.settings.google_protobuf_descriptors %} +from .google_proto_descriptor_pool import default_google_proto_descriptor_pool +{% endif %} {% endif %} if TYPE_CHECKING: From 2cede42157ff373f9dc2a168b97dcc30c195f8cf Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Wed, 18 Jun 2025 18:50:39 -0400 Subject: [PATCH 13/17] update readme --- betterproto2/docs/descriptors.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/betterproto2/docs/descriptors.md b/betterproto2/docs/descriptors.md index 38439300..de521dee 100644 --- a/betterproto2/docs/descriptors.md +++ b/betterproto2/docs/descriptors.md @@ -3,3 +3,9 @@ Google's protoc plugin for Python generated DESCRIPTOR fields that enable reflection capabilities in many libraries (e.g. grpc, grpclib, mcap). By default, betterproto2 doesn't generate these as it introduces a dependency on `protobuf`. If you're okay with this dependency and want to generate DESCRIPTORs, use the compiler option `python_betterproto2_opt=google_protobuf_descriptors`. + + +## grpclib Reflection + +In order to properly use reflection right now, you will need to modify the `DescriptorPool` that is used by grpclib's `ServerReflection`. To do so, take a look at the use of `ServerReflection.extend` in the `test_grpclib_reflection` test in https://github.com/vmagamedov/grpclib/blob/master/tests/grpc/test_grpclib_reflection.py + In the future, once https://github.com/vmagamedov/grpclib/pull/204 is merged, you will be able to pass the `default_google_proto_descriptor_pool` into the `ServerReflection.extend` class method. From aa3af0ccb7f5b49d85f11a475b9bd427b6ae9630 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Wed, 25 Jun 2025 08:26:02 -0400 Subject: [PATCH 14/17] pr changes --- betterproto2/mkdocs.yml | 1 + betterproto2/pyproject.toml | 4 +- betterproto2/tests/test_deprecated.py | 7 +- betterproto2/uv.lock | 10 ++- .../compile/importing.py | 2 - .../betterproto2_compiler/plugin/models.py | 71 ++----------------- .../betterproto2_compiler/plugin/parser.py | 15 ++-- .../templates/header.py.j2 | 11 --- .../templates/template.py.j2 | 5 ++ .../example_service/example_service.proto | 4 ++ 10 files changed, 40 insertions(+), 90 deletions(-) diff --git a/betterproto2/mkdocs.yml b/betterproto2/mkdocs.yml index a8c90113..9d5b7289 100644 --- a/betterproto2/mkdocs.yml +++ b/betterproto2/mkdocs.yml @@ -14,6 +14,7 @@ nav: - Clients: tutorial/clients.md - API: api.md - Development: development.md + - Protobuf Descriptors: descriptors.md plugins: diff --git a/betterproto2/pyproject.toml b/betterproto2/pyproject.toml index 65798b8e..516bdb51 100644 --- a/betterproto2/pyproject.toml +++ b/betterproto2/pyproject.toml @@ -22,7 +22,8 @@ Repository = "https://github.com/betterproto/python-betterproto2" grpcio = ["grpcio>=1.72.1"] grpclib = ["grpclib>=0.4.8"] pydantic = ["pydantic>=2.11.5"] -all = ["grpclib>=0.4.8", "grpcio>=1.72.1", "pydantic>=2.11.5"] +protobuf = ["protobuf>=5.29.3"] +all = ["grpclib>=0.4.8", "grpcio>=1.72.1", "pydantic>=2.11.5", "protobuf>=5.29.3"] [dependency-groups] dev = [ @@ -38,7 +39,6 @@ dev = [ test = [ "cachelib>=0.13.0", "poethepoet>=0.34.0", - "protobuf>=5.29.3", "pytest>=8.4.0", "pytest-asyncio>=1.0.0", "pytest-cov>=6.1.1", diff --git a/betterproto2/tests/test_deprecated.py b/betterproto2/tests/test_deprecated.py index d9a81c02..2930f6cf 100644 --- a/betterproto2/tests/test_deprecated.py +++ b/betterproto2/tests/test_deprecated.py @@ -7,6 +7,7 @@ Empty, Message, Test, + TestNested, TestServiceStub, ) @@ -26,12 +27,12 @@ def test_deprecated_message(): assert str(record[0].message) == f"{Message.__name__} is deprecated" -def test_deprecated_nested_message(): +def test_deprecated_nested_message_field(): with pytest.warns(DeprecationWarning) as record: - Message(value="hello") + TestNested(nested_value="hello") assert len(record) == 1 - assert str(record[0].message) == f"{Message.__name__} is deprecated" + assert str(record[0].message) == f"TestNested.nested_value is deprecated" def test_message_with_deprecated_field(message): diff --git a/betterproto2/uv.lock b/betterproto2/uv.lock index 8db85352..fcf92ad3 100644 --- a/betterproto2/uv.lock +++ b/betterproto2/uv.lock @@ -68,6 +68,7 @@ dependencies = [ all = [ { name = "grpcio" }, { name = "grpclib" }, + { name = "protobuf" }, { name = "pydantic" }, ] grpcio = [ @@ -76,6 +77,9 @@ grpcio = [ grpclib = [ { name = "grpclib" }, ] +protobuf = [ + { name = "protobuf" }, +] pydantic = [ { name = "pydantic" }, ] @@ -93,7 +97,6 @@ dev = [ test = [ { name = "cachelib" }, { name = "poethepoet" }, - { name = "protobuf" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, @@ -106,12 +109,14 @@ requires-dist = [ { name = "grpcio", marker = "extra == 'grpcio'", specifier = ">=1.72.1" }, { name = "grpclib", marker = "extra == 'all'", specifier = ">=0.4.8" }, { name = "grpclib", marker = "extra == 'grpclib'", specifier = ">=0.4.8" }, + { name = "protobuf", marker = "extra == 'all'", specifier = ">=5.29.3" }, + { name = "protobuf", marker = "extra == 'protobuf'", specifier = ">=5.29.3" }, { name = "pydantic", marker = "extra == 'all'", specifier = ">=2.11.5" }, { name = "pydantic", marker = "extra == 'pydantic'", specifier = ">=2.11.5" }, { name = "python-dateutil", specifier = ">=2.9.0.post0" }, { name = "typing-extensions", specifier = ">=4.14.0" }, ] -provides-extras = ["grpcio", "grpclib", "pydantic", "all"] +provides-extras = ["grpcio", "grpclib", "pydantic", "protobuf", "all"] [package.metadata.requires-dev] dev = [ @@ -126,7 +131,6 @@ dev = [ test = [ { name = "cachelib", specifier = ">=0.13.0" }, { name = "poethepoet", specifier = ">=0.34.0" }, - { name = "protobuf", specifier = ">=5.29.3" }, { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-asyncio", specifier = ">=1.0.0" }, { name = "pytest-cov", specifier = ">=6.1.1" }, diff --git a/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py b/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py index 4f6873c3..cb3f3343 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py +++ b/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py @@ -61,7 +61,6 @@ def get_symbol_reference( imports: set, source_package: str, symbol: str, - request: PluginRequestCompiler, import_suffx: str = "", ) -> tuple[str, str | None]: """ @@ -107,7 +106,6 @@ def get_type_reference( imports=imports, source_package=source_package, symbol=py_type, - request=request, ) return ref diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index 99106bf4..b26feab2 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -216,35 +216,6 @@ def input_filenames(self) -> list[str]: """ return sorted([f.name for f in self.input_files]) - @property - def dependency_imports(self): - """Proto dependencies as Python imports. - - Returns - ------- - list[str] - Imports for each proto dependency of this package resolved to the output - names, not the input file paths. - """ - - def dep_to_pkg_import(dep: str) -> tuple[str, str]: - for output_name, pkg in self.parent_request.output_packages.items(): - if dep in pkg.input_filenames: - ref, ref_import = get_symbol_reference( - package=self.package, - imports=self.imports_end, - source_package=output_name, - request=self.parent_request, - symbol="_COMPILER_VERSION", - import_suffx="_dep", - ) - - # import and check compiler version for safety and to avoid this import being removed. - return (ref, f"{ref_import}\nbetterproto2.check_compiler_version({ref})") - raise ValueError(f"cannot find which output package {dep} belongs to") - - return dict([dep_to_pkg_import(dep) for dep in self.package_proto_obj.dependency]).values() - def get_descriptor_name(self, source_file: FileDescriptorProto): return f"{source_file.name.replace('/', '_').replace('.', '_').upper()}_DESCRIPTOR" @@ -272,6 +243,7 @@ class MessageCompiler(ProtoContentBase): output_file: OutputTemplate proto_obj: DescriptorProto + prefixed_proto_name: str fields: list["FieldCompiler"] = field(default_factory=list) oneofs: list["OneofCompiler"] = field(default_factory=list) builtins_types: set[str] = field(default_factory=set) @@ -280,13 +252,9 @@ class MessageCompiler(ProtoContentBase): def proto_name(self) -> str: return self.proto_obj.name - @property - def prefixed_proto_name(self) -> str: - return self.proto_obj.prefixed_name - @property def py_name(self) -> str: - return pythonize_class_name(self.proto_obj.prefixed_name) + return pythonize_class_name(self.prefixed_proto_name) @property def deprecated(self) -> bool: @@ -296,7 +264,7 @@ def deprecated(self) -> bool: def deprecated_fields(self) -> Iterator[str]: for f in self.fields: if f.deprecated: - yield f.proto_name + yield f.py_name @property def has_deprecated_fields(self) -> bool: @@ -320,7 +288,7 @@ def custom_methods(self) -> list[str]: return methods_source @property - def descriptor_name(self): + def descriptor_name(self) -> str: """Google protobuf library descriptor name. Returns @@ -330,17 +298,6 @@ def descriptor_name(self): """ return self.output_file.get_descriptor_name(self.source_file) - @property - def descriptor(self): - """Google protobuf library descriptor. - - Returns - ------- - str - A binary string of the message's proto descriptor. - """ - return self.proto_obj.SerializeToString() - def is_map(proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto) -> bool: """True if proto_field_obj is a map, otherwise False.""" @@ -637,6 +594,7 @@ class EnumDefinitionCompiler(ProtoContentBase): output_file: OutputTemplate proto_obj: EnumDescriptorProto + prefixed_proto_name: str entries: list["EnumDefinitionCompiler.EnumEntry"] = field(default_factory=list) @dataclass(unsafe_hash=True, kw_only=True) @@ -662,20 +620,16 @@ def __post_init__(self) -> None: def proto_name(self) -> str: return self.proto_obj.name - @property - def prefixed_proto_name(self) -> str: - return self.proto_obj.prefixed_name - @property def py_name(self) -> str: - return pythonize_class_name(self.proto_obj.prefixed_name) + return pythonize_class_name(self.prefixed_proto_name) @property def deprecated(self) -> bool: return bool(self.proto_obj.options and self.proto_obj.options.deprecated) @property - def descriptor_name(self): + def descriptor_name(self) -> str: """Google protobuf library descriptor name. Returns @@ -685,17 +639,6 @@ def descriptor_name(self): """ return self.output_file.get_descriptor_name(self.source_file) - @property - def descriptor(self): - """Google protobuf library descriptor. - - Returns - ------- - str - A binary string of the enum's proto descriptor. - """ - return self.proto_obj.SerializeToString() - @dataclass(kw_only=True) class ServiceCompiler(ProtoContentBase): diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py index de71c18c..1d609a74 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/parser.py @@ -35,20 +35,21 @@ def traverse( proto_file: FileDescriptorProto, -) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]: +) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int], str], None, None]: # Todo: Keep information about nested hierarchy def _traverse( path: list[int], items: list[EnumDescriptorProto] | list[DescriptorProto], prefix: str = "", - ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int]], None, None]: + ) -> Generator[tuple[EnumDescriptorProto | DescriptorProto, list[int], str], None, None]: for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. should_rename = not isinstance(item, DescriptorProto) or not item.options or not item.options.map_entry # Record prefixed name but *do not* mutate original file. - item.prefixed_name = next_prefix = f"{prefix}.{item.name}" if prefix and should_rename else item.name - yield item, [*path, i] + # We use this prefixed name to create pythonized names. + prefixed_name = next_prefix = f"{prefix}.{item.name}" if prefix and should_rename else item.name + yield item, [*path, i], prefixed_name if isinstance(item, DescriptorProto): # Get nested types. @@ -110,12 +111,13 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: # get the references to input/output messages for each service for output_package_name, output_package in request_data.output_packages.items(): for proto_input_file in output_package.input_files: - for item, path in traverse(proto_input_file): + for item, path, prefixed_proto_name in traverse(proto_input_file): read_protobuf_type( source_file=proto_input_file, item=item, path=path, output_package=output_package, + prefixed_proto_name=prefixed_proto_name, ) # Read Services @@ -189,6 +191,7 @@ def read_protobuf_type( path: list[int], source_file: "FileDescriptorProto", output_package: OutputTemplate, + prefixed_proto_name: str, ) -> None: if isinstance(item, DescriptorProto): if item.options and item.options.map_entry: @@ -198,6 +201,7 @@ def read_protobuf_type( message_data = MessageCompiler( source_file=source_file, output_file=output_package, + prefixed_proto_name=prefixed_proto_name, proto_obj=item, path=path, ) @@ -253,6 +257,7 @@ def read_protobuf_type( enum = EnumDefinitionCompiler( source_file=source_file, output_file=output_package, + prefixed_proto_name=prefixed_proto_name, proto_obj=item, path=path, ) diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 index 892d30d5..aa348bf7 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/header.py.j2 @@ -54,14 +54,3 @@ if TYPE_CHECKING: _COMPILER_VERSION="{{ version }}" betterproto2.check_compiler_version(_COMPILER_VERSION) - -{% if output_file.settings.google_protobuf_descriptors %} - -{% for import in output_file.dependency_imports %} -{{ import.rstrip() }} -{{ "\n" }} -{% endfor %} - -{# Add descriptors to Google protobuf's default pool to be more drop-in compatible with other libraries. #} -{{ output_file.descriptors }} -{% endif %} diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 index 2f705cdf..09562fb5 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2 @@ -116,6 +116,11 @@ default_message_pool.register_message("{{ output_file.package }}", "{{ message.p {{ i }} {% endfor %} +{% if output_file.settings.google_protobuf_descriptors %} +{# Add descriptors to Google protobuf's default pool to be more drop-in compatible with other libraries. #} +{{ output_file.descriptors }} +{% endif %} + {% if output_file.settings.server_generation == "async" %} {% for _, service in output_file.services|dictsort(by="key") %} class {{ (service.py_name + "Base") | add_to_all }}(ServiceBase): diff --git a/betterproto2_compiler/tests/inputs/example_service/example_service.proto b/betterproto2_compiler/tests/inputs/example_service/example_service.proto index 96455cc3..4ef60236 100644 --- a/betterproto2_compiler/tests/inputs/example_service/example_service.proto +++ b/betterproto2_compiler/tests/inputs/example_service/example_service.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package example_service; +import "google/protobuf/struct.proto"; + service Test { rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse); rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse); @@ -12,9 +14,11 @@ service Test { message ExampleRequest { string example_string = 1; int64 example_integer = 2; + google.protobuf.Struct example_struct = 3; } message ExampleResponse { string example_string = 1; int64 example_integer = 2; + google.protobuf.Struct example_struct = 3; } From deb851f663a00453f20c932e1c8e1320d722f6c3 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Wed, 25 Jun 2025 08:32:16 -0400 Subject: [PATCH 15/17] lint --- .../src/betterproto2_compiler/plugin/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index b26feab2..92470f1b 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -34,7 +34,7 @@ from betterproto2 import unwrap -from betterproto2_compiler.compile.importing import get_symbol_reference, get_type_reference, parse_source_type_name +from betterproto2_compiler.compile.importing import get_type_reference, parse_source_type_name from betterproto2_compiler.compile.naming import ( pythonize_class_name, pythonize_enum_member_name, From 9da89da2982852754f97ac3911b7dbbbc5c6f8ba Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Wed, 25 Jun 2025 09:42:40 -0400 Subject: [PATCH 16/17] remove import_suffix --- .../compile/importing.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py b/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py index cb3f3343..6b7b8c08 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py +++ b/betterproto2_compiler/src/betterproto2_compiler/compile/importing.py @@ -61,7 +61,6 @@ def get_symbol_reference( imports: set, source_package: str, symbol: str, - import_suffx: str = "", ) -> tuple[str, str | None]: """ Return a Python symbol within a proto package. Adds the import if @@ -74,12 +73,12 @@ def get_symbol_reference( return (reference_sibling(symbol), None) if py_package[: len(current_package)] == current_package: - return reference_descendent(current_package, imports, py_package, symbol, import_suffx) + return reference_descendent(current_package, imports, py_package, symbol) if current_package[: len(py_package)] == py_package: - return reference_ancestor(current_package, imports, py_package, symbol, import_suffx) + return reference_ancestor(current_package, imports, py_package, symbol) - return reference_cousin(current_package, imports, py_package, symbol, import_suffx) + return reference_cousin(current_package, imports, py_package, symbol) def get_type_reference( @@ -129,7 +128,7 @@ def reference_sibling(py_type: str) -> str: def reference_descendent( - current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "" + current_package: list[str], imports: set[str], py_package: list[str], py_type: str ) -> tuple[str, str]: """ Returns a reference to a python type in a package that is a descendent of the @@ -140,7 +139,7 @@ def reference_descendent( string_from = ".".join(importing_descendent[:-1]) string_import = importing_descendent[-1] if string_from: - string_alias = f"{'_'.join(importing_descendent)}{import_suffix}" + string_alias = f"{'_'.join(importing_descendent)}" import_to_add = f"from .{string_from} import {string_import} as {string_alias}" imports.add(import_to_add) return (f"{string_alias}.{py_type}", import_to_add) @@ -151,7 +150,7 @@ def reference_descendent( def reference_ancestor( - current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "" + current_package: list[str], imports: set[str], py_package: list[str], py_type: str ) -> tuple[str, str]: """ Returns a reference to a python type in a package which is an ancestor to the @@ -163,20 +162,20 @@ def reference_ancestor( distance_up = len(current_package) - len(py_package) if py_package: string_import = py_package[-1] - string_alias = f"_{'_' * distance_up}{string_import}__{import_suffix}" + string_alias = f"_{'_' * distance_up}{string_import}__" string_from = f"..{'.' * distance_up}" import_to_add = f"from {string_from} import {string_import} as {string_alias}" imports.add(import_to_add) return (f"{string_alias}.{py_type}", import_to_add) else: - string_alias = f"{'_' * distance_up}{py_type}__{import_suffix}" + string_alias = f"{'_' * distance_up}{py_type}__" import_to_add = f"from .{'.' * distance_up} import {py_type} as {string_alias}" imports.add(import_to_add) return (string_alias, import_to_add) def reference_cousin( - current_package: list[str], imports: set[str], py_package: list[str], py_type: str, import_suffix: str = "" + current_package: list[str], imports: set[str], py_package: list[str], py_type: str ) -> tuple[str, str]: """ Returns a reference to a python type in a package that is not descendent, ancestor @@ -192,7 +191,6 @@ def reference_cousin( f"{'_' * distance_up}" + "__".join([safe_snake_case(name) for name in py_package[len(shared_ancestry) :]]) + "__" - + import_suffix ) import_to_add = f"from {string_from} import {string_import} as {string_alias}" imports.add(import_to_add) From 38c865b69bc1256ea9507cd9db2eeac7182b7382 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Wed, 25 Jun 2025 10:27:52 -0400 Subject: [PATCH 17/17] remove source code info from descriptors --- .../betterproto2_compiler/plugin/models.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index 92470f1b..5b339735 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -53,6 +53,7 @@ MethodDescriptorProto, OneofDescriptorProto, ServiceDescriptorProto, + SourceCodeInfo, ) from betterproto2_compiler.lib.google.protobuf.compiler import CodeGeneratorRequest from betterproto2_compiler.settings import Settings @@ -228,13 +229,20 @@ def descriptors(self): str A list of pool registrations for proto descriptors. """ - return "\n".join( - [ - f"{self.get_descriptor_name(f)} = " - + f"default_google_proto_descriptor_pool.AddSerializedFile({f.SerializeToString()})" - for f in self.input_files - ] - ) + descriptors: list[str] = [] + + for f in self.input_files: + # Remove the source_code_info field since it is not needed at runtime. + source_code_info: SourceCodeInfo | None = f.source_code_info + f.source_code_info = None + + descriptors.append( + f"{self.get_descriptor_name(f)} = default_google_proto_descriptor_pool.AddSerializedFile({bytes(f)})" + ) + + f.source_code_info = source_code_info + + return "\n".join(descriptors) @dataclass(kw_only=True)