Skip to content

Commit 11fbfbb

Browse files
[feat] Support google.protobuf.Struct (#10)
This commit makes the denizens of `google/protobuf/struct.proto` be represented as native Python objects (primitives, sequences, and maps).
1 parent 2f40188 commit 11fbfbb

File tree

9 files changed

+408
-18
lines changed

9 files changed

+408
-18
lines changed

packages/proto-plus/proto/marshal/collections/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .map import MapComposite
15+
from .maps import MapComposite
1616
from .repeated import Repeated
1717
from .repeated import RepeatedComposite
1818

File renamed without changes.

packages/proto-plus/proto/marshal/collections/repeated.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ def _pb_type(self):
102102
canary = copy.deepcopy(self.pb).add()
103103
return type(canary)
104104

105+
def __eq__(self, other):
106+
if super().__eq__(other):
107+
return True
108+
return tuple([i for i in self]) == tuple(other)
109+
105110
def __getitem__(self, key):
106111
return self._marshal.to_python(self._pb_type, self.pb[key])
107112

packages/proto-plus/proto/marshal/marshal.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
from google.protobuf import message
1818
from google.protobuf import duration_pb2
1919
from google.protobuf import timestamp_pb2
20+
from google.protobuf import struct_pb2
2021
from google.protobuf import wrappers_pb2
2122

2223
from proto.marshal import compat
2324
from proto.marshal.collections import MapComposite
2425
from proto.marshal.collections import Repeated
2526
from proto.marshal.collections import RepeatedComposite
2627
from proto.marshal.rules import dates
28+
from proto.marshal.rules import struct
2729
from proto.marshal.rules import wrappers
2830

2931

@@ -130,6 +132,15 @@ def reset(self):
130132
self.register(wrappers_pb2.UInt32Value, wrappers.UInt32ValueRule())
131133
self.register(wrappers_pb2.UInt64Value, wrappers.UInt64ValueRule())
132134

135+
# Register the google.protobuf.Struct wrappers.
136+
#
137+
# These are aware of the marshal that created them, because they
138+
# create RepeatedComposite and MapComposite instances directly and
139+
# need to pass the marshal to them.
140+
self.register(struct_pb2.Value, struct.ValueRule(marshal=self))
141+
self.register(struct_pb2.ListValue, struct.ListValueRule(marshal=self))
142+
self.register(struct_pb2.Struct, struct.StructRule(marshal=self))
143+
133144
def to_python(self, proto_type, value, *, absent: bool = None):
134145
# Internal protobuf has its own special type for lists of values.
135146
# Return a view around it that implements MutableSequence.
@@ -147,14 +158,21 @@ def to_python(self, proto_type, value, *, absent: bool = None):
147158
return rule.to_python(value, absent=absent)
148159

149160
def to_proto(self, proto_type, value, *, strict: bool = False):
150-
# For our repeated and map view objects, simply return the
151-
# underlying pb.
152-
if isinstance(value, (Repeated, MapComposite)):
153-
return value.pb
154-
155-
# Convert lists and tuples recursively.
156-
if isinstance(value, (list, tuple)):
157-
return type(value)([self.to_proto(proto_type, i) for i in value])
161+
# The protos in google/protobuf/struct.proto are exceptional cases,
162+
# because they can and should represent themselves as lists and dicts.
163+
# These cases are handled in their rule classes.
164+
if proto_type not in (struct_pb2.Value,
165+
struct_pb2.ListValue, struct_pb2.Struct):
166+
# For our repeated and map view objects, simply return the
167+
# underlying pb.
168+
if isinstance(value, (Repeated, MapComposite)):
169+
return value.pb
170+
171+
# Convert lists and tuples recursively.
172+
if isinstance(value, (list, tuple)):
173+
return type(value)(
174+
[self.to_proto(proto_type, i) for i in value],
175+
)
158176

159177
# Convert dictionaries recursively when the proto type is a map.
160178
# This is slightly more complicated than converting a list or tuple
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright 2018 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import collections.abc
16+
17+
from google.protobuf import struct_pb2
18+
19+
from proto.marshal.collections import maps
20+
from proto.marshal.collections import repeated
21+
22+
23+
class ValueRule:
24+
"""A rule to marshal between google.protobuf.Value and Python values."""
25+
26+
def __init__(self, *, marshal):
27+
self._marshal = marshal
28+
29+
def to_python(self, value, *, absent: bool = None):
30+
"""Coerce the given value to the appropriate Python type.
31+
32+
Note that setting ``null_value`` is distinct from not setting
33+
a value, and the absent value will raise an exception.
34+
"""
35+
kind = value.WhichOneof('kind')
36+
if kind == 'null_value':
37+
return None
38+
if kind == 'bool_value':
39+
return bool(value.bool_value)
40+
if kind == 'number_value':
41+
return float(value.number_value)
42+
if kind == 'string_value':
43+
return str(value.string_value)
44+
if kind == 'struct_value':
45+
return self._marshal.to_python(
46+
struct_pb2.Struct,
47+
value.struct_value,
48+
absent=False,
49+
)
50+
if kind == 'list_value':
51+
return self._marshal.to_python(
52+
struct_pb2.ListValue,
53+
value.list_value,
54+
absent=False,
55+
)
56+
raise AttributeError
57+
58+
def to_proto(self, value) -> struct_pb2.Value:
59+
"""Return a protobuf Value object representing this value."""
60+
if isinstance(value, struct_pb2.Value):
61+
return value
62+
if value is None:
63+
return struct_pb2.Value(null_value=0)
64+
if isinstance(value, bool):
65+
return struct_pb2.Value(bool_value=value)
66+
if isinstance(value, (int, float)):
67+
return struct_pb2.Value(number_value=float(value))
68+
if isinstance(value, str):
69+
return struct_pb2.Value(string_value=value)
70+
if isinstance(value, collections.abc.Sequence):
71+
return struct_pb2.Value(
72+
list_value=self._marshal.to_proto(struct_pb2.ListValue, value),
73+
)
74+
if isinstance(value, collections.abc.Mapping):
75+
return struct_pb2.Value(
76+
struct_value=self._marshal.to_proto(struct_pb2.Struct, value),
77+
)
78+
raise ValueError('Unable to coerce value: %r' % value)
79+
80+
81+
class ListValueRule:
82+
"""A rule translating google.protobuf.ListValue and list-like objects."""
83+
84+
def __init__(self, *, marshal):
85+
self._marshal = marshal
86+
87+
def to_python(self, value, *, absent: bool = None):
88+
"""Coerce the given value to a Python sequence."""
89+
return repeated.RepeatedComposite(value.values, marshal=self._marshal)
90+
91+
def to_proto(self, value) -> struct_pb2.ListValue:
92+
# We got a proto, or else something we sent originally.
93+
# Preserve the instance we have.
94+
if isinstance(value, struct_pb2.ListValue):
95+
return value
96+
if isinstance(value, repeated.RepeatedComposite):
97+
return struct_pb2.ListValue(values=[v for v in value.pb])
98+
99+
# We got a list (or something list-like); convert it.
100+
return struct_pb2.ListValue(values=[
101+
self._marshal.to_proto(struct_pb2.Value, v) for v in value
102+
])
103+
104+
105+
class StructRule:
106+
"""A rule translating google.protobuf.Struct and dict-like objects."""
107+
108+
def __init__(self, *, marshal):
109+
self._marshal = marshal
110+
111+
def to_python(self, value, *, absent: bool = None):
112+
"""Coerce the given value to a Python mapping."""
113+
return maps.MapComposite(value.fields, marshal=self._marshal)
114+
115+
def to_proto(self, value) -> struct_pb2.Struct:
116+
# We got a proto, or else something we sent originally.
117+
# Preserve the instance we have.
118+
if isinstance(value, struct_pb2.Struct):
119+
return value
120+
if isinstance(value, maps.MapComposite):
121+
return struct_pb2.Struct(
122+
fields={k: v for k, v in value.pb.items()},
123+
)
124+
125+
# We got a dict (or something dict-like); convert it.
126+
answer = struct_pb2.Struct(fields={
127+
k: self._marshal.to_proto(struct_pb2.Value, v)
128+
for k, v in value.items()
129+
})
130+
return answer

packages/proto-plus/proto/message.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def __setattr__(self, key, value):
489489
pb_type = self._meta.fields[key].pb_type
490490
pb_value = marshal.to_proto(pb_type, value)
491491

492-
# We *always* clear the existing field.
492+
# Clear the existing field.
493493
# This is the only way to successfully write nested falsy values,
494494
# because otherwise MergeFrom will no-op on them.
495495
self._pb.ClearField(key)

packages/proto-plus/tests/conftest.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,10 @@ def pytest_runtest_setup(item):
5757
module = getattr(item.module, name)
5858
pool.AddSerializedFile(module.DESCRIPTOR.serialized_pb)
5959
fd = pool.FindFileByName(module.DESCRIPTOR.name)
60-
for message_name, descriptor in fd.message_types_by_name.items():
61-
new_message = reflection.GeneratedProtocolMessageType(
62-
message_name,
63-
(message.Message,),
64-
{'DESCRIPTOR': descriptor, '__module__': None},
65-
)
66-
sym_db.RegisterMessage(new_message)
67-
setattr(module, message_name, new_message)
60+
61+
# Register all the messages to the symbol database and the
62+
# module. Do this recursively if there are nested messages.
63+
_register_messages(module, fd.message_types_by_name, sym_db)
6864

6965
# Track which modules had new message classes loaded.
7066
# This is used below to wire the new classes into the marshal.
@@ -74,10 +70,25 @@ def pytest_runtest_setup(item):
7470
# then reload the appropriate modules so the marshal is using the new ones.
7571
if 'wrappers_pb2' in reloaded:
7672
imp.reload(rules.wrappers)
73+
if 'struct_pb2' in reloaded:
74+
imp.reload(rules.struct)
7775
if reloaded.intersection({'timestamp_pb2', 'duration_pb2'}):
7876
imp.reload(rules.dates)
7977

8078

8179
def pytest_runtest_teardown(item):
8280
Marshal._instances.clear()
8381
[i.stop() for i in item._mocks]
82+
83+
84+
def _register_messages(scope, iterable, sym_db):
85+
"""Create and register messages from the file descriptor."""
86+
for name, descriptor in iterable.items():
87+
new_msg = reflection.GeneratedProtocolMessageType(
88+
name,
89+
(message.Message,),
90+
{'DESCRIPTOR': descriptor, '__module__': None},
91+
)
92+
sym_db.RegisterMessage(new_msg)
93+
setattr(scope, name, new_msg)
94+
_register_messages(new_msg, descriptor.nested_types_by_name, sym_db)

packages/proto-plus/tests/test_fields_repeated_composite.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,21 @@ class Baz(proto.Message):
3131

3232
baz = Baz(foos=[Foo(bar=42)])
3333
assert len(baz.foos) == 1
34+
assert baz.foos == baz.foos
3435
assert baz.foos[0].bar == 42
3536

3637

38+
def test_repeated_composite_equality():
39+
class Foo(proto.Message):
40+
bar = proto.Field(proto.INT32, number=1)
41+
42+
class Baz(proto.Message):
43+
foos = proto.RepeatedField(proto.MESSAGE, message=Foo, number=1)
44+
45+
baz = Baz(foos=[Foo(bar=42)])
46+
assert baz.foos == baz.foos
47+
48+
3749
def test_repeated_composite_init_struct():
3850
class Foo(proto.Message):
3951
bar = proto.Field(proto.INT32, number=1)

0 commit comments

Comments
 (0)