From cef3abc696e80e70135656d937e271529743ddae Mon Sep 17 00:00:00 2001 From: zaki Date: Mon, 27 May 2024 22:47:20 +0800 Subject: [PATCH 01/32] feat: do something about service reference --- dubbo/__init__.py | 15 ++++ dubbo/client/__init__.py | 15 ++++ dubbo/client/tri/__init__.py | 15 ++++ dubbo/client/tri/client_call.py | 38 +++++++++ dubbo/common/__init__.py | 15 ++++ dubbo/common/compression/__init__.py | 15 ++++ dubbo/common/compression/compression.py | 37 +++++++++ dubbo/common/compression/gzip.py | 39 +++++++++ dubbo/common/config/__init__.py | 15 ++++ dubbo/common/extensions/__init__.py | 15 ++++ dubbo/common/extensions/extension.py | 49 +++++++++++ dubbo/common/extensions/protocols_loader.py | 64 ++++++++++++++ dubbo/common/node.py | 42 ++++++++++ dubbo/common/url.py | 92 +++++++++++++++++++++ dubbo/config/__init__.py | 15 ++++ dubbo/config/application_config.py | 34 ++++++++ dubbo/config/protocol_config.py | 44 ++++++++++ dubbo/config/reference_config.py | 37 +++++++++ dubbo/imports/__init__.py | 15 ++++ dubbo/imports/imports.py | 23 ++++++ dubbo/protocols/__init__.py | 15 ++++ dubbo/protocols/invocation.py | 18 ++++ dubbo/protocols/invoker.py | 35 ++++++++ dubbo/protocols/protocol.py | 39 +++++++++ dubbo/protocols/triple/__init__.py | 15 ++++ dubbo/protocols/triple/triple_protocol.py | 31 +++++++ dubbo/pydubbo.py | 17 ++++ tests/__init__.py | 15 ++++ tests/common/__init__.py | 15 ++++ tests/common/url_test.py | 78 +++++++++++++++++ 30 files changed, 912 insertions(+) create mode 100644 dubbo/__init__.py create mode 100644 dubbo/client/__init__.py create mode 100644 dubbo/client/tri/__init__.py create mode 100644 dubbo/client/tri/client_call.py create mode 100644 dubbo/common/__init__.py create mode 100644 dubbo/common/compression/__init__.py create mode 100644 dubbo/common/compression/compression.py create mode 100644 dubbo/common/compression/gzip.py create mode 100644 dubbo/common/config/__init__.py create mode 100644 dubbo/common/extensions/__init__.py create mode 100644 dubbo/common/extensions/extension.py create mode 100644 dubbo/common/extensions/protocols_loader.py create mode 100644 dubbo/common/node.py create mode 100644 dubbo/common/url.py create mode 100644 dubbo/config/__init__.py create mode 100644 dubbo/config/application_config.py create mode 100644 dubbo/config/protocol_config.py create mode 100644 dubbo/config/reference_config.py create mode 100644 dubbo/imports/__init__.py create mode 100644 dubbo/imports/imports.py create mode 100644 dubbo/protocols/__init__.py create mode 100644 dubbo/protocols/invocation.py create mode 100644 dubbo/protocols/invoker.py create mode 100644 dubbo/protocols/protocol.py create mode 100644 dubbo/protocols/triple/__init__.py create mode 100644 dubbo/protocols/triple/triple_protocol.py create mode 100644 dubbo/pydubbo.py create mode 100644 tests/__init__.py create mode 100644 tests/common/__init__.py create mode 100644 tests/common/url_test.py diff --git a/dubbo/__init__.py b/dubbo/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/client/__init__.py b/dubbo/client/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/client/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/client/tri/__init__.py b/dubbo/client/tri/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/client/tri/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/client/tri/client_call.py b/dubbo/client/tri/client_call.py new file mode 100644 index 0000000..ee17f7b --- /dev/null +++ b/dubbo/client/tri/client_call.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import abc + + +class UnaryUnaryMultiCallable(abc.ABC): + """Affords invoking a unary-unary RPC from client-side.""" + + @abc.abstractmethod + def __call__( + self, + request, + timeout=None, + compression=None + ): + """ + Synchronously invokes the underlying RPC. + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + compression: An element of dubbo.common.compression, e.g. 'gzip'. + """ + + raise NotImplementedError() diff --git a/dubbo/common/__init__.py b/dubbo/common/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/common/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/common/compression/__init__.py b/dubbo/common/compression/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/common/compression/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/common/compression/compression.py b/dubbo/common/compression/compression.py new file mode 100644 index 0000000..ed1569d --- /dev/null +++ b/dubbo/common/compression/compression.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
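(Editorial aside, not part of the patch.) The UnaryUnaryMultiCallable interface above only fixes the calling convention for a unary-unary RPC. A minimal sketch of a conforming implementation, against the interface as defined at this point in the series; the class, behavior, and payloads below are assumptions for illustration only:

from dubbo.client.tri.client_call import UnaryUnaryMultiCallable

class EchoCallable(UnaryUnaryMultiCallable):
    """Toy implementation: returns the request unchanged instead of calling a remote service."""

    def __call__(self, request, timeout=None, compression=None):
        # A real triple-protocol implementation would serialize the request,
        # send it to the provider, honor the timeout, and apply compression.
        return request

echo = EchoCallable()
assert echo(b"ping", timeout=1.0, compression=None) == b"ping"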
+ +import abc + + +class Compression(abc.ABC): + """Compression interface.""" + + def compress(self, data: bytes) -> bytes: + """ + Compress data. + :param data: data to be compressed. + :return: compressed data. + """ + raise NotImplementedError("Method 'compress' is not implemented.") + + def decompress(self, data: bytes) -> bytes: + """ + Decompress data. + :param data: data to be decompressed. + :return: decompressed data. + """ + raise NotImplementedError("Method 'decompress' is not implemented.") diff --git a/dubbo/common/compression/gzip.py b/dubbo/common/compression/gzip.py new file mode 100644 index 0000000..099fa8a --- /dev/null +++ b/dubbo/common/compression/gzip.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip + +from dubbo.common.compression.compression import Compression + + +class GzipCompression(Compression): + """Gzip compression implementation.""" + + def compress(self, data: bytes) -> bytes: + """ + Compress data using gzip. + :param data: data to be compressed. + :return: compressed data. + """ + return gzip.compress(data) + + def decompress(self, data: bytes) -> bytes: + """ + Decompress data using gzip. + :param data: data to be decompressed. + :return: decompressed data. + """ + return gzip.decompress(data) diff --git a/dubbo/common/config/__init__.py b/dubbo/common/config/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/common/config/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/common/extensions/__init__.py b/dubbo/common/extensions/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/common/extensions/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/common/extensions/extension.py b/dubbo/common/extensions/extension.py new file mode 100644 index 0000000..4524516 --- /dev/null +++ b/dubbo/common/extensions/extension.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class ExtensionLoader: + """ + Extension loader Interface. + Any class that implements this interface can be called an extension loader. + """ + + @classmethod + def set(cls, name: str, extension): + """ + Set the extension. + :param name: The name of the extension. + :param extension: The extension. + """ + raise NotImplementedError("Method 'set' is not implemented.") + + @classmethod + def get(cls, name: str): + """ + Get the extension. + :param name: The name of the extension. + :return: The extension. + """ + raise NotImplementedError("Method 'get' is not implemented.") + + @classmethod + def register(cls, name: str): + """ + Register the extension. + This method is a decorator. + :param name: The name of the extension. + """ + raise NotImplementedError("Method 'register' is not implemented.") diff --git a/dubbo/common/extensions/protocols_loader.py b/dubbo/common/extensions/protocols_loader.py new file mode 100644 index 0000000..f37dd1d --- /dev/null +++ b/dubbo/common/extensions/protocols_loader.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
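(Editorial aside, not part of the patch.) The ExtensionLoader interface above is realized for protocols by the ProtocolExtensionLoader that follows. A rough sketch of the intended register/get round-trip; the 'demo' name and DemoProtocol class are made up for illustration, while 'tri' is the name actually registered later in this patch:

from dubbo.common.extensions.protocols_loader import ProtocolExtensionLoader
from dubbo.protocols.protocol import Protocol

@ProtocolExtensionLoader.register('demo')  # decorator stores DemoProtocol under the name 'demo'
class DemoProtocol(Protocol):
    """Placeholder protocol used only for this sketch."""

demo = ProtocolExtensionLoader.get('demo')  # instantiates the registered class
assert isinstance(demo, DemoProtocol)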
+ +from dubbo.common.extensions import extension +from dubbo.protocols.protocol import Protocol + + +class ProtocolExtensionLoader(extension.ExtensionLoader): + """ + Protocol extension loader. + """ + # Store the protocol classes. k: name, v: protocol class + __protocols: dict[str, type] = dict() + + @classmethod + def set(cls, name: str, protocol_class: type): + """ + Set the protocol. + :param name: The name of the protocols. + :param protocol_class: The protocol class. + """ + # Check if the protocol_class is a subclass of Protocol. + if not issubclass(protocol_class, Protocol): + raise TypeError(f"Need a subclass of Protocol, but got {protocol_class}") + cls.__protocols[name] = protocol_class + + @classmethod + def get(cls, name) -> Protocol: + """ + Get the protocols. + :param name: The name of the protocols. + :return: The protocol instance. + """ + try: + return cls.__protocols.get(name)() + except KeyError: + raise KeyError(f"Protocol extension not found: {name}") + + @classmethod + def register(cls, name: str): + """ + Register the protocols. + This method is a decorator. + :param name: The name of the protocols. + """ + + def decorator(protocol_class): + cls.set(name, protocol_class) + return protocol_class + + return decorator diff --git a/dubbo/common/node.py b/dubbo/common/node.py new file mode 100644 index 0000000..c75f9f3 --- /dev/null +++ b/dubbo/common/node.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.common.url import URL + + +class Node: + """ + Node. + """ + + def get_url(self) -> URL: + """ + Get URL. + :return: URL + """ + raise NotImplementedError("Method 'get_url' is not implemented.") + + def is_available(self) -> bool: + """ + Is available. + """ + raise NotImplementedError("Method 'is_available' is not implemented.") + + def destroy(self) -> None: + """ + Destroy + """ + raise NotImplementedError("Method 'destroy' is not implemented.") diff --git a/dubbo/common/url.py b/dubbo/common/url.py new file mode 100644 index 0000000..34e1694 --- /dev/null +++ b/dubbo/common/url.py @@ -0,0 +1,92 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import urllib.parse as ulp
+
+
+class URL:
+
+    def __init__(self,
+                 protocol: str,
+                 host: str,
+                 port: int,
+                 username: str = None,
+                 password: str = None,
+                 path: str = None,
+                 params: dict[str, str] = None
+                 ):
+        """
+        Initialize URL object.
+        :param protocol: protocols.
+        :param host: host.
+        :param port: port.
+        :param username: username.
+        :param password: password.
+        :param path: path.
+        :param params: parameters.
+        """
+        self.protocol = protocol
+        self.host = host
+        self.port = port
+        self.username = username
+        if password and not username:
+            raise ValueError("Password must be set with username.")
+        self.password = password
+        self.path = path or ''
+        self.params = params or {}
+
+    def to_str(self, encoded: bool = False) -> str:
+        """
+        Convert URL object to URL string.
+        :param encoded: Whether to encode the URL, default is False.
+        """
+        # Set username and password
+        auth_part = f"{self.username}:{self.password}@" if self.username or self.password else ""
+        # Set location (host[:port]); omit the port part when no port is set
+        netloc = f"{auth_part}{self.host}" + (f":{self.port}" if self.port else "")
+        query = ulp.urlencode(self.params)
+        path = self.path
+
+        url_parts = (self.protocol, netloc, path, '', query, '')
+        url_str = str(ulp.urlunparse(url_parts))
+
+        if encoded:
+            url_str = ulp.quote(url_str)
+
+        return url_str
+
+    def __str__(self):
+        return self.to_str()
+
+
+def parse_url(url: str, encoded: bool = False) -> URL:
+    """
+    Parse URL string to URL object.
+    :param url: URL string.
+    :param encoded: Whether the URL is encoded, default is False.
+    :return: URL
+    """
+    if encoded:
+        url = ulp.unquote(url)
+    parsed_url = ulp.urlparse(url)
+    protocol = parsed_url.scheme
+    host = parsed_url.hostname
+    port = parsed_url.port
+    path = parsed_url.path
+    params = {k: v[0] for k, v in ulp.parse_qs(parsed_url.query).items()}
+    username = parsed_url.username or ''
+    password = parsed_url.password or ''
+    return URL(protocol, host, port, username, password, path, params)
diff --git a/dubbo/config/__init__.py b/dubbo/config/__init__.py
new file mode 100644
index 0000000..bcba37a
--- /dev/null
+++ b/dubbo/config/__init__.py
@@ -0,0 +1,15 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/dubbo/config/application_config.py b/dubbo/config/application_config.py
new file mode 100644
index 0000000..ce76327
--- /dev/null
+++ b/dubbo/config/application_config.py
@@ -0,0 +1,34 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class ApplicationConfig: + """ + Application Config + """ + + def __init__(self): + # name + self.name = '' + # version + self.version = '' + # owner + self.owner = '' + # organization(BU) + self.organization = '' + # architecture, e.g. intl, china + self.architecture = '' + # environment, e.g. dev, test, production + self.environment = '' diff --git a/dubbo/config/protocol_config.py b/dubbo/config/protocol_config.py new file mode 100644 index 0000000..09f09b9 --- /dev/null +++ b/dubbo/config/protocol_config.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class ProtocolConfig: + """ + Protocol Config + """ + + def __init__(self): + # protocol name + self.name = '' + # service ip address + self.host = '' + # service port + self.port = None + # protocol codec + self.codec = '' + # serialization + self.serialization = '' + # charset + self.charset = '' + # ssl + self.ssl = False + # transporter + self.transporter = '' + # server + self.server = '' + # client + self.client = '' + # register + self.register = False diff --git a/dubbo/config/reference_config.py b/dubbo/config/reference_config.py new file mode 100644 index 0000000..b3f0f7c --- /dev/null +++ b/dubbo/config/reference_config.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.protocols.protocol import Protocol + + +class ReferenceConfig: + """ + ReferenceConfig is the configuration of service consumer. 
+ """ + + def __init__(self): + # A particular Protocol implementation is determined by the protocol attribute in the URL. + self.protocol = None + # A ProxyFactory implementation that will generate a reference service's proxy + self.pxy = None + # The interface proxy reference + self.ref = None + # The invoker of the reference service + self.invoker = None + # The flag whether the ReferenceConfig has been initialized + self.initialized = False + + diff --git a/dubbo/imports/__init__.py b/dubbo/imports/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/imports/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/imports/imports.py b/dubbo/imports/imports.py new file mode 100644 index 0000000..838183c --- /dev/null +++ b/dubbo/imports/imports.py @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides a centralized collection of Dubbo SPI implementations. +It simplifies plugin installation using Python's import mechanism. +""" + +# Load Protocol Extension +from dubbo.protocols.triple.triple_protocol import TripleProtocol diff --git a/dubbo/protocols/__init__.py b/dubbo/protocols/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/protocols/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/protocols/invocation.py b/dubbo/protocols/invocation.py new file mode 100644 index 0000000..54a1481 --- /dev/null +++ b/dubbo/protocols/invocation.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class Invocation: + pass diff --git a/dubbo/protocols/invoker.py b/dubbo/protocols/invoker.py new file mode 100644 index 0000000..14c9f29 --- /dev/null +++ b/dubbo/protocols/invoker.py @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.common.node import Node + + +class Invoker(Node): + """ + Invoker. + """ + + def get_interface(self): + """ + Get service interface. + """ + raise NotImplementedError("Method 'get_interface' is not implemented.") + + def invoke(self): + """ + Invoke. + """ + raise NotImplementedError("Method 'invoke' is not implemented.") diff --git a/dubbo/protocols/protocol.py b/dubbo/protocols/protocol.py new file mode 100644 index 0000000..a6df8da --- /dev/null +++ b/dubbo/protocols/protocol.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
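(Editorial aside, not part of the patch.) The Invoker above inherits the Node contract (get_url/is_available/destroy) and adds get_interface/invoke. A toy implementation to make the expected shape concrete; the class, URL, and interface name are illustrative assumptions only:

from dubbo.common.url import URL
from dubbo.protocols.invoker import Invoker

class EchoInvoker(Invoker):
    """Toy invoker bound to a fixed URL and interface name."""

    def __init__(self, url: URL, interface: str):
        self._url = url
        self._interface = interface

    def get_url(self) -> URL:
        return self._url

    def is_available(self) -> bool:
        return True

    def destroy(self) -> None:
        pass

    def get_interface(self):
        return self._interface

    def invoke(self):
        # A real invoker would build an Invocation and perform the remote call.
        return f"echo:{self._interface}"

invoker = EchoInvoker(URL("tri", "127.0.0.1", 50051), "org.apache.dubbo.DemoService")
assert invoker.is_available() and invoker.invoke() == "echo:org.apache.dubbo.DemoService"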
+ +from dubbo.common.url import URL +from dubbo.protocols.invoker import Invoker + + +class Protocol: + """ + RPC Protocol extension interface, which encapsulates the details of remote invocation. + """ + + def export(self, invoker: Invoker): + """ + Export service for remote invocation + :param invoker: service invoker + """ + raise NotImplementedError("Method 'export' is not implemented.") + + def refer(self, service_type, url: URL): + """ + Refer a remote service. + :param service_type: service class + :param url: URL address for the remote service + """ + raise NotImplementedError("Method 'refer' is not implemented.") diff --git a/dubbo/protocols/triple/__init__.py b/dubbo/protocols/triple/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/protocols/triple/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/protocols/triple/triple_protocol.py b/dubbo/protocols/triple/triple_protocol.py new file mode 100644 index 0000000..32b6043 --- /dev/null +++ b/dubbo/protocols/triple/triple_protocol.py @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.common.extensions.protocols_loader import ProtocolExtensionLoader +from dubbo.protocols.protocol import Protocol + + +@ProtocolExtensionLoader.register('tri') +class TripleProtocol(Protocol): + """ + Triple protocols. + """ + + def export(self, invoker): + raise NotImplementedError('export method is not implemented') + + def refer(self, service_type, url): + raise NotImplementedError('refer method is not implemented') diff --git a/dubbo/pydubbo.py b/dubbo/pydubbo.py new file mode 100644 index 0000000..4da89bf --- /dev/null +++ b/dubbo/pydubbo.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import imports.imports # Load the extensions. diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/common/__init__.py b/tests/common/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/tests/common/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/common/url_test.py b/tests/common/url_test.py new file mode 100644 index 0000000..09ac1ef --- /dev/null +++ b/tests/common/url_test.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+from dubbo.common import url as dubbo_url
+
+
+class TestURL(unittest.TestCase):
+
+    def test_parse_url_with_params(self):
+        url = "registry://192.168.1.7:9090/org.apache.dubbo.service1?param1=value1&param2=value2"
+        parsed = dubbo_url.parse_url(url)
+        self.assertEqual(parsed.protocol, "registry")
+        self.assertEqual(parsed.host, "192.168.1.7")
+        self.assertEqual(parsed.port, 9090)
+        self.assertEqual(parsed.path, "/org.apache.dubbo.service1")
+        self.assertEqual(parsed.params, {"param1": "value1", "param2": "value2"})
+        self.assertEqual(parsed.username, "")
+        self.assertEqual(parsed.password, "")
+        self.assertEqual(parsed.to_str(), url)
+
+    def test_parse_url_with_auth(self):
+        url = "http://username:password@10.20.130.230:8080/list?version=1.0.0"
+        parsed = dubbo_url.parse_url(url)
+        self.assertEqual(parsed.protocol, "http")
+        self.assertEqual(parsed.host, "10.20.130.230")
+        self.assertEqual(parsed.port, 8080)
+        self.assertEqual(parsed.path, "/list")
+        self.assertEqual(parsed.params, {"version": "1.0.0"})
+        self.assertEqual(parsed.username, "username")
+        self.assertEqual(parsed.password, "password")
+        self.assertEqual(parsed.to_str(), url)
+
+    def test_to_str_with_encoded(self):
+        url = "http://username:password@10.20.130.230:8080/list?version=1.0.0"
+        parsed = dubbo_url.parse_url(url)
+        encoded_url = parsed.to_str(encoded=True)
+        self.assertNotEqual(encoded_url, url)
+        self.assertTrue('%3F' in encoded_url)
+
+    def test_to_str_without_params(self):
+        url = "http://www.example.com"
+        parsed = dubbo_url.parse_url(url)
+        self.assertEqual(parsed.protocol, "http")
+        self.assertEqual(parsed.host, "www.example.com")
+        self.assertEqual(parsed.path, "")
+        self.assertEqual(parsed.params, {})
+        self.assertEqual(parsed.username, "")
+        self.assertEqual(parsed.password, "")
+        self.assertEqual(parsed.to_str(), "http://www.example.com")
+
+    def test_parse_url_encoded(self):
+        encoded_url = "http%3A%2F%2Fwww.facebook.com%2Ffriends%3Fparam1%3Dvalue1%26param2%3Dvalue2"
+        parsed = dubbo_url.parse_url(encoded_url, encoded=True)
+        self.assertEqual(parsed.protocol, "http")
+        self.assertEqual(parsed.host, "www.facebook.com")
+        self.assertEqual(parsed.path, "/friends")
+        self.assertEqual(parsed.params, {"param1": "value1", "param2": "value2"})
+        self.assertEqual(parsed.username, "")
+        self.assertEqual(parsed.password, "")
+        self.assertEqual(parsed.to_str(), "http://www.facebook.com/friends?param1=value1&param2=value2")
+
+
+if __name__ == '__main__':
+    unittest.main()

From 1e36ffdbf21478a04fbfcf809b926d67d64ee22f Mon Sep 17 00:00:00 2001
From: zaki
Date: Mon, 27 May 2024 22:52:25 +0800
Subject: [PATCH 02/32] fix: fix ci

---
 .flake8                          | 5 ++++-
 dubbo/config/reference_config.py | 4 ----
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/.flake8 b/.flake8
index 3ab0e58..6aa0376 100644
--- a/.flake8
+++ b/.flake8
@@ -16,9 +16,12 @@ ignore =
 max-line-length = 120
 
 exclude =
+    .idea,
     .git,
     __pycache__,
     docs
 
 per-file-ignores =
-    __init__.py:F401
\ No newline at end of file
+    __init__.py:F401
+    dubbo/imports/imports.py:F401
+    dubbo/pydubbo.py:F401
diff --git a/dubbo/config/reference_config.py b/dubbo/config/reference_config.py
index b3f0f7c..45f3832 100644
--- a/dubbo/config/reference_config.py
+++ b/dubbo/config/reference_config.py
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dubbo.protocols.protocol import Protocol - class ReferenceConfig: """ @@ -33,5 +31,3 @@ def __init__(self): self.invoker = None # The flag whether the ReferenceConfig has been initialized self.initialized = False - - From 4db92d7c82a98874cb62e4c1b8a485e3e601c8b2 Mon Sep 17 00:00:00 2001 From: zaki Date: Mon, 27 May 2024 23:09:32 +0800 Subject: [PATCH 03/32] feat: define UnaryUnaryMultiCallable --- dubbo/client/tri/client_call.py | 59 ++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/dubbo/client/tri/client_call.py b/dubbo/client/tri/client_call.py index ee17f7b..d770270 100644 --- a/dubbo/client/tri/client_call.py +++ b/dubbo/client/tri/client_call.py @@ -33,6 +33,63 @@ def __call__( request: The request value for the RPC. timeout: An optional duration of time in seconds to allow for the RPC. compression: An element of dubbo.common.compression, e.g. 'gzip'. + + Returns: + The response value for the RPC. + + Raises: + RpcError: Indicating that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. + """ + + raise NotImplementedError("Method '__call__' is not implemented.") + + @abc.abstractmethod + def with_call( + self, + request, + timeout=None, + compression=None + ): + """ + Synchronously invokes the underlying RPC. + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + compression: An element of dubbo.common.compression, e.g. 'gzip'. + + Returns: + The response value for the RPC. + + Raises: + RpcError: Indicating that the RPC terminated with non-OK status. The + raised RpcError will also be a Call for the RPC affording the RPC's + metadata, status code, and details. """ - raise NotImplementedError() + raise NotImplementedError("Method 'with_call' is not implemented.") + + @abc.abstractmethod + def async_call( + self, + request, + timeout=None, + compression=None + ): + """ + Asynchronously invokes the underlying RPC. + Args: + request: The request value for the RPC. + timeout: An optional duration of time in seconds to allow for the RPC. + compression: An element of dubbo.common.compression, e.g. 'gzip'. + + Returns: + An object that is both a Call for the RPC and a Future. + In the event of RPC completion, the return Call-Future's result + value will be the response message of the RPC. + Should the event terminate with non-OK status, + the returned Call-Future's exception value will be an RpcError. + """ + + raise NotImplementedError("Method 'async_call' is not implemented.") From 81e22e9b81d077ddafcc5a62dddd8f6f12661054 Mon Sep 17 00:00:00 2001 From: zaki Date: Thu, 30 May 2024 23:00:09 +0800 Subject: [PATCH 04/32] feat: update applicationConfig --- dubbo/config/application_config.py | 39 ++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/dubbo/config/application_config.py b/dubbo/config/application_config.py index ce76327..7694f2c 100644 --- a/dubbo/config/application_config.py +++ b/dubbo/config/application_config.py @@ -18,17 +18,30 @@ class ApplicationConfig: """ Application Config """ + # name + name: str + # version + version: str + # owner + owner: str + # organization(BU) + organization: str + # architecture, e.g. intl, china + architecture: str + # environment, e.g. 
dev, test, production + environment: str - def __init__(self): - # name - self.name = '' - # version - self.version = '' - # owner - self.owner = '' - # organization(BU) - self.organization = '' - # architecture, e.g. intl, china - self.architecture = '' - # environment, e.g. dev, test, production - self.environment = '' + def __init__(self, **kwargs): + for key, value in kwargs.items(): + if key in self.__annotations__: + setattr(self, key, value) + else: + raise AttributeError(f"{key} is not a valid attribute of {self.__class__.__name__}") + + def __repr__(self): + return (f"") From 1b38707df53c07adac6d1384369c7b3edb3222a6 Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 2 Jun 2024 20:36:23 +0800 Subject: [PATCH 05/32] feat: do some work related to service reference --- .../python-lint-and-license-check.yml | 6 + .../imports.py => config/extensions.ini | 10 +- dubbo/__init__.py | 2 + dubbo/_dubbo.py | 62 ++++++++++ dubbo/common/extension.py | 111 ++++++++++++++++++ dubbo/common/extensions/extension.py | 49 -------- dubbo/common/extensions/protocols_loader.py | 64 ---------- dubbo/common/url.py | 14 ++- .../common/{extensions => utils}/__init__.py | 0 dubbo/common/utils/file_utils.py | 57 +++++++++ dubbo/config/config_manger.py | 40 +++++++ dubbo/config/reference_config.py | 2 + dubbo/{pydubbo.py => logger/__init__.py} | 2 +- dubbo/logger/logger.py | 59 ++++++++++ dubbo/logger/loguru_logger.py | 49 ++++++++ dubbo/protocols/triple/triple_protocol.py | 2 - {dubbo/imports => tests/logger}/__init__.py | 0 tests/logger/test_loguru_logger.py | 35 ++++++ 18 files changed, 435 insertions(+), 129 deletions(-) rename dubbo/imports/imports.py => config/extensions.ini (76%) create mode 100644 dubbo/_dubbo.py create mode 100644 dubbo/common/extension.py delete mode 100644 dubbo/common/extensions/extension.py delete mode 100644 dubbo/common/extensions/protocols_loader.py rename dubbo/common/{extensions => utils}/__init__.py (100%) create mode 100644 dubbo/common/utils/file_utils.py create mode 100644 dubbo/config/config_manger.py rename dubbo/{pydubbo.py => logger/__init__.py} (94%) create mode 100644 dubbo/logger/logger.py create mode 100644 dubbo/logger/loguru_logger.py rename {dubbo/imports => tests/logger}/__init__.py (100%) create mode 100644 tests/logger/test_loguru_logger.py diff --git a/.github/workflows/python-lint-and-license-check.yml b/.github/workflows/python-lint-and-license-check.yml index b552112..f9b6323 100644 --- a/.github/workflows/python-lint-and-license-check.yml +++ b/.github/workflows/python-lint-and-license-check.yml @@ -19,6 +19,12 @@ jobs: pip install flake8 flake8 . +# - name: Type check with MyPy +# run: | +# # fail if there are any MyPy errors +# pip install mypy +# mypy ./dubbo + check-license: runs-on: ubuntu-latest steps: diff --git a/dubbo/imports/imports.py b/config/extensions.ini similarity index 76% rename from dubbo/imports/imports.py rename to config/extensions.ini index 838183c..75a139d 100644 --- a/dubbo/imports/imports.py +++ b/config/extensions.ini @@ -14,10 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -This module provides a centralized collection of Dubbo SPI implementations. -It simplifies plugin installation using Python's import mechanism. 
-""" - -# Load Protocol Extension -from dubbo.protocols.triple.triple_protocol import TripleProtocol +# style: from a.b.c import D => a.b.c:D +[dubbo.logger:Logger] +loguru = dubbo.logger.loguru_logger:LoguruLogger \ No newline at end of file diff --git a/dubbo/__init__.py b/dubbo/__init__.py index bcba37a..2d866e1 100644 --- a/dubbo/__init__.py +++ b/dubbo/__init__.py @@ -13,3 +13,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from dubbo._dubbo import Dubbo diff --git a/dubbo/_dubbo.py b/dubbo/_dubbo.py new file mode 100644 index 0000000..11d58d6 --- /dev/null +++ b/dubbo/_dubbo.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from dubbo.config.application_config import ApplicationConfig +from dubbo.config.config_manger import ConfigManager +from dubbo.config.reference_config import ReferenceConfig + + +class Dubbo: + """ + Dubbo program entry. + """ + _instance = None + _lock: threading.Lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """ + Singleton mode. + """ + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + self._config_manager: ConfigManager = ConfigManager() + + def with_application(self, application_config: ApplicationConfig) -> 'Dubbo': + """ + Set application configuration. + :return: Dubbo instance. + """ + self._config_manager.add_config(application_config) + return self + + def with_reference(self, reference_config: ReferenceConfig) -> 'Dubbo': + """ + Set reference configuration. + """ + self._config_manager.add_config(reference_config) + return self + + def start(self): + """ + Start Dubbo. + """ + pass diff --git a/dubbo/common/extension.py b/dubbo/common/extension.py new file mode 100644 index 0000000..9d54768 --- /dev/null +++ b/dubbo/common/extension.py @@ -0,0 +1,111 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from typing import Dict, Type + +from dubbo.common.utils.file_utils import IniFileUtils + + +def load_type(config_str: str) -> Type: + """ + Dynamically load a type from a module based on a configuration string. + + :param config_str: Configuration string in the format 'module_path:class_name'. + :return: The loaded type. + :raises ValueError: If the configuration string format is incorrect or the object is not a type. + :raises ImportError: If there is an error importing the specified module. + :raises AttributeError: If the specified attribute is not found in the module. + """ + module_path, class_name = '', '' + try: + # Split the configuration string to obtain the module path and object name + module_path, class_name = config_str.rsplit(':', 1) + + # Import the module + module = importlib.import_module(module_path) + + # Get the specified type from the module + loaded_type = getattr(module, class_name) + + # Ensure the loaded object is a type (class) + if not isinstance(loaded_type, type): + raise ValueError(f"'{class_name}' is not a valid type in module '{module_path}'") + + return loaded_type + except ValueError as e: + raise ValueError("Invalid configuration string. Use 'module_path:class_name' format.") from e + except ImportError as e: + raise ImportError(f"Error importing module '{module_path}': {e}") from e + except AttributeError as e: + raise AttributeError(f"Module '{module_path}' does not have an attribute '{class_name}'") from e + + +class ExtensionLoader: + """ + Extension loader. + """ + + def __init__(self, class_type: type, classes: Dict[str, str]): + self._class_type = class_type # class type + self._classes = {} + self._instances = {} + for name, config_str in classes.items(): + o = load_type(config_str) + if issubclass(o, class_type): + self._classes[name] = o + else: + raise ValueError(f"Class {class_type} is not a subclass of {object}") + + @property + def class_type(self): + return self._class_type + + @property + def classes(self): + return self._classes + + def get_instance(self, name: str): + if name not in self._instances: + self._instances[name] = self._classes[name]() + return self._instances[name] + + +class ExtensionManager: + """ + Extension manager. + """ + + def __init__(self): + self._extension_loaders: Dict[type, ExtensionLoader] = {} + + def initialize(self): + """ + Read the configuration file and initialize the extension manager. + """ + extensions = IniFileUtils.parse_config("extensions.ini") + for section, classes in extensions.items(): + class_type = load_type(section) + self._extension_loaders[class_type] = ExtensionLoader(class_type, classes) + + def get_extension_loader(self, class_type: type) -> ExtensionLoader: + """ + Get the extension loader for a given class object. + + :param class_type: Class object. + :return: Extension loader. + """ + return self._extension_loaders.get(class_type) diff --git a/dubbo/common/extensions/extension.py b/dubbo/common/extensions/extension.py deleted file mode 100644 index 4524516..0000000 --- a/dubbo/common/extensions/extension.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class ExtensionLoader: - """ - Extension loader Interface. - Any class that implements this interface can be called an extension loader. - """ - - @classmethod - def set(cls, name: str, extension): - """ - Set the extension. - :param name: The name of the extension. - :param extension: The extension. - """ - raise NotImplementedError("Method 'set' is not implemented.") - - @classmethod - def get(cls, name: str): - """ - Get the extension. - :param name: The name of the extension. - :return: The extension. - """ - raise NotImplementedError("Method 'get' is not implemented.") - - @classmethod - def register(cls, name: str): - """ - Register the extension. - This method is a decorator. - :param name: The name of the extension. - """ - raise NotImplementedError("Method 'register' is not implemented.") diff --git a/dubbo/common/extensions/protocols_loader.py b/dubbo/common/extensions/protocols_loader.py deleted file mode 100644 index f37dd1d..0000000 --- a/dubbo/common/extensions/protocols_loader.py +++ /dev/null @@ -1,64 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dubbo.common.extensions import extension -from dubbo.protocols.protocol import Protocol - - -class ProtocolExtensionLoader(extension.ExtensionLoader): - """ - Protocol extension loader. - """ - # Store the protocol classes. k: name, v: protocol class - __protocols: dict[str, type] = dict() - - @classmethod - def set(cls, name: str, protocol_class: type): - """ - Set the protocol. - :param name: The name of the protocols. - :param protocol_class: The protocol class. - """ - # Check if the protocol_class is a subclass of Protocol. - if not issubclass(protocol_class, Protocol): - raise TypeError(f"Need a subclass of Protocol, but got {protocol_class}") - cls.__protocols[name] = protocol_class - - @classmethod - def get(cls, name) -> Protocol: - """ - Get the protocols. - :param name: The name of the protocols. - :return: The protocol instance. - """ - try: - return cls.__protocols.get(name)() - except KeyError: - raise KeyError(f"Protocol extension not found: {name}") - - @classmethod - def register(cls, name: str): - """ - Register the protocols. - This method is a decorator. - :param name: The name of the protocols. 
- """ - - def decorator(protocol_class): - cls.set(name, protocol_class) - return protocol_class - - return decorator diff --git a/dubbo/common/url.py b/dubbo/common/url.py index 34e1694..090144b 100644 --- a/dubbo/common/url.py +++ b/dubbo/common/url.py @@ -23,10 +23,10 @@ def __init__(self, protocol: str, host: str, port: int, - username: str = None, - password: str = None, - path: str = None, - params: dict[str, str] = None + username: str = '', + password: str = '', + path: str = '', + params=None ): """ Initialize URL object. @@ -38,6 +38,8 @@ def __init__(self, :param path: path. :param params: parameters. """ + if params is None: + params = {} self.protocol = protocol self.host = host self.port = port @@ -87,6 +89,6 @@ def parse_url(url: str, encoded: bool = False) -> URL: port = parsed_url.port path = parsed_url.path params = {k: v[0] for k, v in ulp.parse_qs(parsed_url.query).items()} - username = parsed_url.username or '' - password = parsed_url.password or '' + username = parsed_url.username + password = parsed_url.password return URL(protocol, host, port, username, password, path, params) diff --git a/dubbo/common/extensions/__init__.py b/dubbo/common/utils/__init__.py similarity index 100% rename from dubbo/common/extensions/__init__.py rename to dubbo/common/utils/__init__.py diff --git a/dubbo/common/utils/file_utils.py b/dubbo/common/utils/file_utils.py new file mode 100644 index 0000000..ce98aca --- /dev/null +++ b/dubbo/common/utils/file_utils.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import configparser +from pathlib import Path +from typing import Dict + + +def get_dubbo_dir() -> Path: + """ + Get the dubbo directory. eg: /path/to/dubbo + """ + current_path = Path(__file__).resolve().parent + + for parent in current_path.parents: + if parent.name == "dubbo": + return parent + + raise FileNotFoundError("The 'dubbo' directory was not found in the path hierarchy.") + + +_CONFIG_DIR = get_dubbo_dir().parent / "config" + + +class IniFileUtils: + """ + Ini configuration file utils. + """ + + @staticmethod + def parse_config(file_name: str, file_dir: str = None, encoding: str = "utf-8") -> Dict[str, Dict[str, str]]: + """ + Parse the configuration file. + :param file_name: The name of the configuration file. + :param file_dir: The directory of the configuration file. + :param encoding: The encoding of the configuration file. + :return: The configuration. 
+ """ + # get the file path + file_path = Path(file_dir) / file_name if file_dir else _CONFIG_DIR / file_name + # read the configuration file + cf = configparser.ConfigParser() + cf.read(file_path, encoding=encoding) + # get the configuration dict + return {section: dict(cf[section]) for section in cf.sections()} diff --git a/dubbo/config/config_manger.py b/dubbo/config/config_manger.py new file mode 100644 index 0000000..11fc536 --- /dev/null +++ b/dubbo/config/config_manger.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.config.application_config import ApplicationConfig + + +class ConfigManager: + """ + Configuration manager. + """ + # unique config in application + unique_config_types = [ + ApplicationConfig, + ] + + def __init__(self): + self._configs_cache = {} + + def add_config(self, config): + """ + Add configuration. + :param config: configuration. + """ + if type(config) not in self.unique_config_types or config.__class__ not in self._configs_cache: + self._configs_cache[type(config)] = config + else: + raise ValueError(f"Config type {type(config)} already exists.") diff --git a/dubbo/config/reference_config.py b/dubbo/config/reference_config.py index 45f3832..f364eda 100644 --- a/dubbo/config/reference_config.py +++ b/dubbo/config/reference_config.py @@ -25,6 +25,8 @@ def __init__(self): self.protocol = None # A ProxyFactory implementation that will generate a reference service's proxy self.pxy = None + # The interface of the reference service + self.method = None # The interface proxy reference self.ref = None # The invoker of the reference service diff --git a/dubbo/pydubbo.py b/dubbo/logger/__init__.py similarity index 94% rename from dubbo/pydubbo.py rename to dubbo/logger/__init__.py index 4da89bf..4c74427 100644 --- a/dubbo/pydubbo.py +++ b/dubbo/logger/__init__.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -import imports.imports # Load the extensions. +from .logger import Logger diff --git a/dubbo/logger/logger.py b/dubbo/logger/logger.py new file mode 100644 index 0000000..3221a9a --- /dev/null +++ b/dubbo/logger/logger.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class Logger: + + def log(self, level: str, msg: str) -> None: + """ + Log + """ + raise NotImplementedError("Method 'log' is not implemented.") + + def debug(self, msg: str) -> None: + """ + Debug log + """ + raise NotImplementedError("Method 'debug' is not implemented.") + + def info(self, msg: str) -> None: + """ + Info log + """ + raise NotImplementedError("Method 'info' is not implemented.") + + def warning(self, msg: str) -> None: + """ + Warning log + """ + raise NotImplementedError("Method 'warning' is not implemented.") + + def error(self, msg: str) -> None: + """ + Error log + """ + raise NotImplementedError("Method 'error' is not implemented.") + + def critical(self, msg: str) -> None: + """ + Critical log + """ + raise NotImplementedError("Method 'critical' is not implemented.") + + def exception(self, msg: str) -> None: + """ + Exception log + """ + raise NotImplementedError("Method 'exception' is not implemented.") diff --git a/dubbo/logger/loguru_logger.py b/dubbo/logger/loguru_logger.py new file mode 100644 index 0000000..12e62c2 --- /dev/null +++ b/dubbo/logger/loguru_logger.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from loguru import logger + +from dubbo.logger import Logger + + +class LoguruLogger(Logger): + """ + Loguru logger. + """ + + def __init__(self): + self.logger = logger.opt(depth=1) + + def log(self, level: str, msg: str) -> None: + self.logger.log(level, msg) + + def debug(self, msg: str) -> None: + self.logger.debug(msg) + + def info(self, msg: str) -> None: + self.logger.info(msg) + + def warning(self, msg: str) -> None: + self.logger.warning(msg) + + def error(self, msg: str) -> None: + self.logger.error(msg) + + def critical(self, msg: str) -> None: + self.logger.critical(msg) + + def exception(self, msg: str) -> None: + self.logger.exception(msg) diff --git a/dubbo/protocols/triple/triple_protocol.py b/dubbo/protocols/triple/triple_protocol.py index 32b6043..85357b8 100644 --- a/dubbo/protocols/triple/triple_protocol.py +++ b/dubbo/protocols/triple/triple_protocol.py @@ -14,11 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dubbo.common.extensions.protocols_loader import ProtocolExtensionLoader from dubbo.protocols.protocol import Protocol -@ProtocolExtensionLoader.register('tri') class TripleProtocol(Protocol): """ Triple protocols. diff --git a/dubbo/imports/__init__.py b/tests/logger/__init__.py similarity index 100% rename from dubbo/imports/__init__.py rename to tests/logger/__init__.py diff --git a/tests/logger/test_loguru_logger.py b/tests/logger/test_loguru_logger.py new file mode 100644 index 0000000..849fc58 --- /dev/null +++ b/tests/logger/test_loguru_logger.py @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from dubbo.logger.loguru_logger import LoguruLogger + + +class TestLoguruLogger(unittest.TestCase): + + def test_loguru_logger(self): + logger = LoguruLogger() + logger.debug("Debug log") + logger.info("Info log") + logger.warning("Warning log") + logger.error("Error log") + logger.critical("Critical log") + try: + return 1 / 0 + except ZeroDivisionError: + logger.exception("exception!!!") + assert True From 2f48f0ee54247366b3d16c03bcee70451e1682f7 Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 2 Jun 2024 20:48:49 +0800 Subject: [PATCH 06/32] feat: add ci --- .github/workflows/unittest.yml | 22 ++++++++++++++++++++++ dubbo/common/url.py | 16 +++++++--------- tests/common/{url_test.py => test_url.py} | 0 3 files changed, 29 insertions(+), 9 deletions(-) create mode 100644 .github/workflows/unittest.yml rename tests/common/{url_test.py => test_url.py} (100%) diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml new file mode 100644 index 0000000..3b5481b --- /dev/null +++ b/.github/workflows/unittest.yml @@ -0,0 +1,22 @@ +name: Run Unittests + +on: [push, pull_request] + +jobs: + unittest: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + pip install -r requirements.txt + + - name: Run unittests + run: | + python -m unittest discover -s tests -p 'test_*.py' diff --git a/dubbo/common/url.py b/dubbo/common/url.py index 090144b..b3c3594 100644 --- a/dubbo/common/url.py +++ b/dubbo/common/url.py @@ -23,10 +23,10 @@ def __init__(self, protocol: str, host: str, port: int, - username: str = '', - password: str = '', - path: str = '', - params=None + username: str = None, + password: str = None, + path: str = None, + params: dict[str, str] = None ): """ Initialize URL object. @@ -38,8 +38,6 @@ def __init__(self, :param path: path. :param params: parameters. 
""" - if params is None: - params = {} self.protocol = protocol self.host = host self.port = port @@ -58,7 +56,7 @@ def to_str(self, encoded: bool = False) -> str: # Set username and password auth_part = f"{self.username}:{self.password}@" if self.username or self.password else "" # Set location - netloc = f"{auth_part}{self.host}{self.port}" + netloc = f"{auth_part}{self.host}{':' + str(self.port) if self.port else ''}" query = ulp.urlencode(self.params) path = self.path @@ -89,6 +87,6 @@ def parse_url(url: str, encoded: bool = False) -> URL: port = parsed_url.port path = parsed_url.path params = {k: v[0] for k, v in ulp.parse_qs(parsed_url.query).items()} - username = parsed_url.username - password = parsed_url.password + username = parsed_url.username or '' + password = parsed_url.password or '' return URL(protocol, host, port, username, password, path, params) diff --git a/tests/common/url_test.py b/tests/common/test_url.py similarity index 100% rename from tests/common/url_test.py rename to tests/common/test_url.py From ca6172614df1f7194ae3ddda5e92d067a840841a Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 2 Jun 2024 20:50:20 +0800 Subject: [PATCH 07/32] fix: fix ci --- requirements.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a38bb99 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +loguru~=0.7.2 \ No newline at end of file From 31172fc73a2a6f335e465ab966b33152369fa402 Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 2 Jun 2024 20:52:31 +0800 Subject: [PATCH 08/32] fix: fix ci --- .licenserc.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.licenserc.yaml b/.licenserc.yaml index 35f2542..0ef3499 100644 --- a/.licenserc.yaml +++ b/.licenserc.yaml @@ -61,6 +61,7 @@ header: # `header` section is configurations for source codes license header. - '.gitignore' - '.github' - '.flake8' + - 'requirements.txt' comment: on-failure # on what condition license-eye will comment on the pull request, `on-failure`, `always`, `never`. # license-location-threshold specifies the index threshold where the license header can be located, From 7f3ee0173ced83e302687c13202a3e155f8f3eba Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 2 Jun 2024 22:41:28 +0800 Subject: [PATCH 09/32] feat: add logger feat --- dubbo/config/application_config.py | 78 +++++++++++++++++-------- dubbo/logger/__init__.py | 2 +- dubbo/logger/{logger.py => _logger.py} | 22 +++++++ test.py | 26 +++++++++ tests/config/__init__.py | 15 +++++ tests/config/test_application_config.py | 31 ++++++++++ 6 files changed, 148 insertions(+), 26 deletions(-) rename dubbo/logger/{logger.py => _logger.py} (83%) create mode 100644 test.py create mode 100644 tests/config/__init__.py create mode 100644 tests/config/test_application_config.py diff --git a/dubbo/config/application_config.py b/dubbo/config/application_config.py index 7694f2c..bd03648 100644 --- a/dubbo/config/application_config.py +++ b/dubbo/config/application_config.py @@ -13,35 +13,63 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dubbo import logger +from dubbo.common.extension import ExtensionManager + class ApplicationConfig: """ Application Config """ - # name - name: str - # version - version: str - # owner - owner: str - # organization(BU) - organization: str - # architecture, e.g. 
intl, china - architecture: str - # environment, e.g. dev, test, production - environment: str - - def __init__(self, **kwargs): - for key, value in kwargs.items(): - if key in self.__annotations__: - setattr(self, key, value) - else: - raise AttributeError(f"{key} is not a valid attribute of {self.__class__.__name__}") + + def __init__( + self, + name: str, + version: str = '', + owner: str = '', + organization: str = '', + architecture: str = '', + environment: str = '', + logger_name: str = 'loguru'): + self._name = name + self._version = version + self._owner = owner + self._organization = organization + self._architecture = architecture + self._environment = environment + self._logger_name = logger_name + self._extension_manager = ExtensionManager() + + # init application config + self.do_init() + + def do_init(self): + # init ExtensionManager + self._extension_manager.initialize() + # init logger + self.init_logger(self._logger_name) + + @property + def logger_name(self): + return self._logger_name + + @logger_name.setter + def logger_name(self, logger_name: str): + self._logger_name = logger_name + self.init_logger(logger_name) + + def init_logger(self, logger_name: str): + """ + Init logger + """ + # init dubbo logger + instance = self._extension_manager.get_extension_loader(logger.Logger).get_instance(logger_name) + logger.set_logger(instance) def __repr__(self): - return (f"") + return (f"") diff --git a/dubbo/logger/__init__.py b/dubbo/logger/__init__.py index 4c74427..0dbbad3 100644 --- a/dubbo/logger/__init__.py +++ b/dubbo/logger/__init__.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .logger import Logger +from ._logger import Logger, set_logger, get_logger diff --git a/dubbo/logger/logger.py b/dubbo/logger/_logger.py similarity index 83% rename from dubbo/logger/logger.py rename to dubbo/logger/_logger.py index 3221a9a..72d7163 100644 --- a/dubbo/logger/logger.py +++ b/dubbo/logger/_logger.py @@ -57,3 +57,25 @@ def exception(self, msg: str) -> None: Exception log """ raise NotImplementedError("Method 'exception' is not implemented.") + + +# global logger, default logger is None +_LOGGER: Logger = Logger() + + +def get_logger() -> Logger: + """ + Get logger + """ + return _LOGGER + + +def set_logger(logger: Logger) -> None: + """ + Set logger + """ + global _LOGGER + if logger is not None and isinstance(logger, Logger): + _LOGGER = logger + else: + raise ValueError("Invalid logger") diff --git a/test.py b/test.py new file mode 100644 index 0000000..e551af8 --- /dev/null +++ b/test.py @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dubbo.config.application_config import ApplicationConfig +from dubbo.logger import get_logger + +if __name__ == '__main__': + ApplicationConfig(name='dubbo') + dubbo_logger = get_logger() + dubbo_logger.debug('debug') + dubbo_logger.info('info') + dubbo_logger.warning('warning') + dubbo_logger.error('error') \ No newline at end of file diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/tests/config/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/config/test_application_config.py b/tests/config/test_application_config.py new file mode 100644 index 0000000..7922a77 --- /dev/null +++ b/tests/config/test_application_config.py @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from dubbo.config.application_config import ApplicationConfig +from dubbo import logger + + +class TestApplicationConfig(unittest.TestCase): + + def test_init_logger(self): + ApplicationConfig(name='dubbo') + dubbo_logger = logger.get_logger() + dubbo_logger.debug('debug') + dubbo_logger.info('info') + dubbo_logger.warning('warning') + dubbo_logger.error('error') + assert True From 9bc8bdd9c6379e77c28e4cf3fcb04bc3ed709b29 Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 2 Jun 2024 22:47:14 +0800 Subject: [PATCH 10/32] fix: fix ci --- dubbo/common/extension.py | 11 ++++++++++- dubbo/config/application_config.py | 11 ++++++----- test.py | 26 -------------------------- 3 files changed, 16 insertions(+), 32 deletions(-) delete mode 100644 test.py diff --git a/dubbo/common/extension.py b/dubbo/common/extension.py index 9d54768..512a035 100644 --- a/dubbo/common/extension.py +++ b/dubbo/common/extension.py @@ -15,6 +15,7 @@ # limitations under the License. 
import importlib +import threading from typing import Dict, Type from dubbo.common.utils.file_utils import IniFileUtils @@ -63,6 +64,7 @@ def __init__(self, class_type: type, classes: Dict[str, str]): self._class_type = class_type # class type self._classes = {} self._instances = {} + self._instance_lock = threading.Lock() for name, config_str in classes.items(): o = load_type(config_str) if issubclass(o, class_type): @@ -79,8 +81,15 @@ def classes(self): return self._classes def get_instance(self, name: str): + # check if the class exists + if name not in self._classes: + raise ValueError(f"Class {name} not found in {self._class_type}") + + # get the instance if name not in self._instances: - self._instances[name] = self._classes[name]() + with self._instance_lock: + if name not in self._instances: + self._instances[name] = self._classes[name]() return self._instances[name] diff --git a/dubbo/config/application_config.py b/dubbo/config/application_config.py index bd03648..2bb7352 100644 --- a/dubbo/config/application_config.py +++ b/dubbo/config/application_config.py @@ -47,7 +47,7 @@ def do_init(self): # init ExtensionManager self._extension_manager.initialize() # init logger - self.init_logger(self._logger_name) + self._update_logger(self._logger_name) @property def logger_name(self): @@ -56,14 +56,15 @@ def logger_name(self): @logger_name.setter def logger_name(self, logger_name: str): self._logger_name = logger_name - self.init_logger(logger_name) + self._update_logger(logger_name) - def init_logger(self, logger_name: str): + def _update_logger(self, logger_name: str): """ - Init logger + Update global logger instance. """ - # init dubbo logger + # get logger instance instance = self._extension_manager.get_extension_loader(logger.Logger).get_instance(logger_name) + # update logger logger.set_logger(instance) def __repr__(self): diff --git a/test.py b/test.py deleted file mode 100644 index e551af8..0000000 --- a/test.py +++ /dev/null @@ -1,26 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dubbo.config.application_config import ApplicationConfig -from dubbo.logger import get_logger - -if __name__ == '__main__': - ApplicationConfig(name='dubbo') - dubbo_logger = get_logger() - dubbo_logger.debug('debug') - dubbo_logger.info('info') - dubbo_logger.warning('warning') - dubbo_logger.error('error') \ No newline at end of file From d66a40ef3c4c296cb92a7a0328555c4e50c2d9d8 Mon Sep 17 00:00:00 2001 From: zaki Date: Mon, 3 Jun 2024 00:06:25 +0800 Subject: [PATCH 11/32] perf: Extension Manager becomes a singleton --- config/extensions.ini | 5 ++-- dubbo/_dubbo.py | 9 ++++++- dubbo/common/extension.py | 31 ++++++++++++++++++++++- dubbo/config/application_config.py | 24 +++++------------- dubbo/logger/__init__.py | 2 +- dubbo/logger/_logger.py | 33 ++++++++++++++++++------- tests/common/test_extension.py | 31 +++++++++++++++++++++++ tests/config/test_application_config.py | 4 ++- 8 files changed, 105 insertions(+), 34 deletions(-) create mode 100644 tests/common/test_extension.py diff --git a/config/extensions.ini b/config/extensions.ini index 75a139d..77c1749 100644 --- a/config/extensions.ini +++ b/config/extensions.ini @@ -14,6 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -# style: from a.b.c import D => a.b.c:D -[dubbo.logger:Logger] -loguru = dubbo.logger.loguru_logger:LoguruLogger \ No newline at end of file +[dubbo.logger.Logger] +loguru = dubbo.logger.loguru_logger.LoguruLogger \ No newline at end of file diff --git a/dubbo/_dubbo.py b/dubbo/_dubbo.py index 11d58d6..e80a826 100644 --- a/dubbo/_dubbo.py +++ b/dubbo/_dubbo.py @@ -55,8 +55,15 @@ def with_reference(self, reference_config: ReferenceConfig) -> 'Dubbo': self._config_manager.add_config(reference_config) return self + def _do_init(self): + """ + Initialize Dubbo. + """ + pass + def start(self): """ Start Dubbo. """ - pass + self._do_init() + diff --git a/dubbo/common/extension.py b/dubbo/common/extension.py index 512a035..c7f352f 100644 --- a/dubbo/common/extension.py +++ b/dubbo/common/extension.py @@ -34,7 +34,7 @@ def load_type(config_str: str) -> Type: module_path, class_name = '', '' try: # Split the configuration string to obtain the module path and object name - module_path, class_name = config_str.rsplit(':', 1) + module_path, class_name = config_str.rsplit('.', 1) # Import the module module = importlib.import_module(module_path) @@ -99,16 +99,26 @@ class ExtensionManager: """ def __init__(self): + self._initialized = False self._extension_loaders: Dict[type, ExtensionLoader] = {} + @property + def initialized(self): + return self._initialized + def initialize(self): """ Read the configuration file and initialize the extension manager. """ + if self._initialized: + return + # read the configuration file extensions = IniFileUtils.parse_config("extensions.ini") + # parse the configuration for section, classes in extensions.items(): class_type = load_type(section) self._extension_loaders[class_type] = ExtensionLoader(class_type, classes) + self._initialized = True def get_extension_loader(self, class_type: type) -> ExtensionLoader: """ @@ -118,3 +128,22 @@ def get_extension_loader(self, class_type: type) -> ExtensionLoader: :return: Extension loader. """ return self._extension_loaders.get(class_type) + + +# global extension manager +_EXTENSION_MANAGER = ExtensionManager() +# lock +_lock = threading.Lock() + + +def get_extension_manager() -> ExtensionManager: + """ + Get the extension manager. + + :return: Extension manager. 
+ """ + if not _EXTENSION_MANAGER.initialized: + with _lock: + if not _EXTENSION_MANAGER.initialized: + _EXTENSION_MANAGER.initialize() + return _EXTENSION_MANAGER \ No newline at end of file diff --git a/dubbo/config/application_config.py b/dubbo/config/application_config.py index 2bb7352..f7df8ad 100644 --- a/dubbo/config/application_config.py +++ b/dubbo/config/application_config.py @@ -13,8 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dubbo import logger -from dubbo.common.extension import ExtensionManager +from dubbo.common import extension + +extension_manager = extension.get_extension_manager() class ApplicationConfig: @@ -38,16 +41,10 @@ def __init__( self._architecture = architecture self._environment = environment self._logger_name = logger_name - self._extension_manager = ExtensionManager() - - # init application config - self.do_init() def do_init(self): - # init ExtensionManager - self._extension_manager.initialize() # init logger - self._update_logger(self._logger_name) + logger.set_logger_by_name(self.logger_name) @property def logger_name(self): @@ -56,16 +53,7 @@ def logger_name(self): @logger_name.setter def logger_name(self, logger_name: str): self._logger_name = logger_name - self._update_logger(logger_name) - - def _update_logger(self, logger_name: str): - """ - Update global logger instance. - """ - # get logger instance - instance = self._extension_manager.get_extension_loader(logger.Logger).get_instance(logger_name) - # update logger - logger.set_logger(instance) + logger.set_logger_by_name(logger_name) def __repr__(self): return (f" None: """ @@ -59,23 +63,34 @@ def exception(self, msg: str) -> None: raise NotImplementedError("Method 'exception' is not implemented.") -# global logger, default logger is None +# global logger, default logger is Logger(), so it will raise an error if it is not set _LOGGER: Logger = Logger() -def get_logger() -> Logger: - """ - Get logger - """ - return _LOGGER - - def set_logger(logger: Logger) -> None: """ - Set logger + Set global logger """ global _LOGGER if logger is not None and isinstance(logger, Logger): _LOGGER = logger else: raise ValueError("Invalid logger") + + +def set_logger_by_name(logger_name: str) -> None: + """ + Set global logger by name + """ + # import extension module here to avoid circular import + from dubbo.common import extension + extension_manager = extension.get_extension_manager() + instance = extension_manager.get_extension_loader(Logger).get_instance(logger_name) + set_logger(instance) + + +def get_logger() -> Logger: + """ + Get global logger + """ + return _LOGGER diff --git a/tests/common/test_extension.py b/tests/common/test_extension.py new file mode 100644 index 0000000..63fc929 --- /dev/null +++ b/tests/common/test_extension.py @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from dubbo.common import extension +from dubbo import logger + + +class TestExtension(unittest.TestCase): + + def test_get_instance(self): + manager = extension.get_extension_manager() + assert manager is not None + loader = manager.get_extension_loader(logger.Logger) + assert loader is not None + dubbo_logger = loader.get_instance("loguru") + assert dubbo_logger is not None diff --git a/tests/config/test_application_config.py b/tests/config/test_application_config.py index 7922a77..d58e0cf 100644 --- a/tests/config/test_application_config.py +++ b/tests/config/test_application_config.py @@ -15,6 +15,7 @@ # limitations under the License. import unittest +from dubbo.common import extension from dubbo.config.application_config import ApplicationConfig from dubbo import logger @@ -22,7 +23,8 @@ class TestApplicationConfig(unittest.TestCase): def test_init_logger(self): - ApplicationConfig(name='dubbo') + config = ApplicationConfig(name='dubbo') + config.do_init() dubbo_logger = logger.get_logger() dubbo_logger.debug('debug') dubbo_logger.info('info') From 8ad1133e7b2979ccbc8318c0d247968a654a422e Mon Sep 17 00:00:00 2001 From: zaki Date: Mon, 3 Jun 2024 00:08:11 +0800 Subject: [PATCH 12/32] fix: fix ci --- dubbo/_dubbo.py | 1 - dubbo/common/extension.py | 2 +- tests/config/test_application_config.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/dubbo/_dubbo.py b/dubbo/_dubbo.py index e80a826..a7915a9 100644 --- a/dubbo/_dubbo.py +++ b/dubbo/_dubbo.py @@ -66,4 +66,3 @@ def start(self): Start Dubbo. """ self._do_init() - diff --git a/dubbo/common/extension.py b/dubbo/common/extension.py index c7f352f..1d4b659 100644 --- a/dubbo/common/extension.py +++ b/dubbo/common/extension.py @@ -146,4 +146,4 @@ def get_extension_manager() -> ExtensionManager: with _lock: if not _EXTENSION_MANAGER.initialized: _EXTENSION_MANAGER.initialize() - return _EXTENSION_MANAGER \ No newline at end of file + return _EXTENSION_MANAGER diff --git a/tests/config/test_application_config.py b/tests/config/test_application_config.py index d58e0cf..3c49553 100644 --- a/tests/config/test_application_config.py +++ b/tests/config/test_application_config.py @@ -15,7 +15,6 @@ # limitations under the License. 
import unittest -from dubbo.common import extension from dubbo.config.application_config import ApplicationConfig from dubbo import logger From 351c0a22d9cb5016e8ee9e7638097eeaa2d234d8 Mon Sep 17 00:00:00 2001 From: zaki Date: Thu, 13 Jun 2024 13:45:31 +0800 Subject: [PATCH 13/32] refactor: Make the code more standardized and robust --- .../python-lint-and-license-check.yml | 10 +- config/extensions.ini | 18 --- dubbo/__init__.py | 2 - dubbo/_dubbo.py | 68 -------- dubbo/client/__init__.py | 15 -- dubbo/client/tri/__init__.py | 15 -- dubbo/client/tri/client_call.py | 95 ----------- dubbo/common/compression/__init__.py | 15 -- dubbo/common/compression/compression.py | 37 ----- dubbo/common/compression/gzip.py | 39 ----- dubbo/common/extension.py | 149 ------------------ .../common/{config => extension}/__init__.py | 1 + .../logger_extension.py} | 41 +++-- dubbo/common/url.py | 92 ----------- dubbo/common/utils/__init__.py | 15 -- dubbo/common/utils/file_utils.py | 57 ------- dubbo/config/__init__.py | 15 -- dubbo/config/application_config.py | 64 -------- dubbo/config/config_manger.py | 40 ----- dubbo/config/protocol_config.py | 44 ------ dubbo/config/reference_config.py | 35 ---- dubbo/logger/__init__.py | 3 +- dubbo/logger/_logger.py | 90 +++++------ dubbo/logger/internal_logger.py | 69 ++++++++ dubbo/logger/loguru_logger.py | 49 ------ dubbo/protocols/__init__.py | 15 -- dubbo/protocols/invoker.py | 35 ---- dubbo/protocols/protocol.py | 39 ----- dubbo/protocols/triple/__init__.py | 15 -- dubbo/protocols/triple/triple_protocol.py | 29 ---- dubbo/{protocols/invocation.py => run.py} | 5 +- requirements.txt | 1 - tests/common/__init__.py | 15 -- tests/common/test_extension.py | 31 ---- tests/common/test_url.py | 78 --------- tests/config/__init__.py | 15 -- tests/config/test_application_config.py | 32 ---- ...guru_logger.py => test_internal_logger.py} | 25 ++- 38 files changed, 151 insertions(+), 1262 deletions(-) delete mode 100644 config/extensions.ini delete mode 100644 dubbo/_dubbo.py delete mode 100644 dubbo/client/__init__.py delete mode 100644 dubbo/client/tri/__init__.py delete mode 100644 dubbo/client/tri/client_call.py delete mode 100644 dubbo/common/compression/__init__.py delete mode 100644 dubbo/common/compression/compression.py delete mode 100644 dubbo/common/compression/gzip.py delete mode 100644 dubbo/common/extension.py rename dubbo/common/{config => extension}/__init__.py (93%) rename dubbo/common/{node.py => extension/logger_extension.py} (59%) delete mode 100644 dubbo/common/url.py delete mode 100644 dubbo/common/utils/__init__.py delete mode 100644 dubbo/common/utils/file_utils.py delete mode 100644 dubbo/config/__init__.py delete mode 100644 dubbo/config/application_config.py delete mode 100644 dubbo/config/config_manger.py delete mode 100644 dubbo/config/protocol_config.py delete mode 100644 dubbo/config/reference_config.py create mode 100644 dubbo/logger/internal_logger.py delete mode 100644 dubbo/logger/loguru_logger.py delete mode 100644 dubbo/protocols/__init__.py delete mode 100644 dubbo/protocols/invoker.py delete mode 100644 dubbo/protocols/protocol.py delete mode 100644 dubbo/protocols/triple/__init__.py delete mode 100644 dubbo/protocols/triple/triple_protocol.py rename dubbo/{protocols/invocation.py => run.py} (92%) delete mode 100644 tests/common/__init__.py delete mode 100644 tests/common/test_extension.py delete mode 100644 tests/common/test_url.py delete mode 100644 tests/config/__init__.py delete mode 100644 tests/config/test_application_config.py rename 
tests/logger/{test_loguru_logger.py => test_internal_logger.py} (65%) diff --git a/.github/workflows/python-lint-and-license-check.yml b/.github/workflows/python-lint-and-license-check.yml index f9b6323..1cbb9cd 100644 --- a/.github/workflows/python-lint-and-license-check.yml +++ b/.github/workflows/python-lint-and-license-check.yml @@ -19,11 +19,11 @@ jobs: pip install flake8 flake8 . -# - name: Type check with MyPy -# run: | -# # fail if there are any MyPy errors -# pip install mypy -# mypy ./dubbo + - name: Type check with MyPy + run: | + # fail if there are any MyPy errors + pip install mypy + mypy ./dubbo check-license: runs-on: ubuntu-latest diff --git a/config/extensions.ini b/config/extensions.ini deleted file mode 100644 index 77c1749..0000000 --- a/config/extensions.ini +++ /dev/null @@ -1,18 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -[dubbo.logger.Logger] -loguru = dubbo.logger.loguru_logger.LoguruLogger \ No newline at end of file diff --git a/dubbo/__init__.py b/dubbo/__init__.py index 2d866e1..bcba37a 100644 --- a/dubbo/__init__.py +++ b/dubbo/__init__.py @@ -13,5 +13,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from dubbo._dubbo import Dubbo diff --git a/dubbo/_dubbo.py b/dubbo/_dubbo.py deleted file mode 100644 index a7915a9..0000000 --- a/dubbo/_dubbo.py +++ /dev/null @@ -1,68 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import threading -from dubbo.config.application_config import ApplicationConfig -from dubbo.config.config_manger import ConfigManager -from dubbo.config.reference_config import ReferenceConfig - - -class Dubbo: - """ - Dubbo program entry. - """ - _instance = None - _lock: threading.Lock = threading.Lock() - - def __new__(cls, *args, **kwargs): - """ - Singleton mode. 
- """ - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self): - self._config_manager: ConfigManager = ConfigManager() - - def with_application(self, application_config: ApplicationConfig) -> 'Dubbo': - """ - Set application configuration. - :return: Dubbo instance. - """ - self._config_manager.add_config(application_config) - return self - - def with_reference(self, reference_config: ReferenceConfig) -> 'Dubbo': - """ - Set reference configuration. - """ - self._config_manager.add_config(reference_config) - return self - - def _do_init(self): - """ - Initialize Dubbo. - """ - pass - - def start(self): - """ - Start Dubbo. - """ - self._do_init() diff --git a/dubbo/client/__init__.py b/dubbo/client/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/dubbo/client/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dubbo/client/tri/__init__.py b/dubbo/client/tri/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/dubbo/client/tri/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dubbo/client/tri/client_call.py b/dubbo/client/tri/client_call.py deleted file mode 100644 index d770270..0000000 --- a/dubbo/client/tri/client_call.py +++ /dev/null @@ -1,95 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc - - -class UnaryUnaryMultiCallable(abc.ABC): - """Affords invoking a unary-unary RPC from client-side.""" - - @abc.abstractmethod - def __call__( - self, - request, - timeout=None, - compression=None - ): - """ - Synchronously invokes the underlying RPC. - Args: - request: The request value for the RPC. - timeout: An optional duration of time in seconds to allow for the RPC. - compression: An element of dubbo.common.compression, e.g. 'gzip'. - - Returns: - The response value for the RPC. - - Raises: - RpcError: Indicating that the RPC terminated with non-OK status. The - raised RpcError will also be a Call for the RPC affording the RPC's - metadata, status code, and details. - """ - - raise NotImplementedError("Method '__call__' is not implemented.") - - @abc.abstractmethod - def with_call( - self, - request, - timeout=None, - compression=None - ): - """ - Synchronously invokes the underlying RPC. - Args: - request: The request value for the RPC. - timeout: An optional duration of time in seconds to allow for the RPC. - compression: An element of dubbo.common.compression, e.g. 'gzip'. - - Returns: - The response value for the RPC. - - Raises: - RpcError: Indicating that the RPC terminated with non-OK status. The - raised RpcError will also be a Call for the RPC affording the RPC's - metadata, status code, and details. - """ - - raise NotImplementedError("Method 'with_call' is not implemented.") - - @abc.abstractmethod - def async_call( - self, - request, - timeout=None, - compression=None - ): - """ - Asynchronously invokes the underlying RPC. - Args: - request: The request value for the RPC. - timeout: An optional duration of time in seconds to allow for the RPC. - compression: An element of dubbo.common.compression, e.g. 'gzip'. - - Returns: - An object that is both a Call for the RPC and a Future. - In the event of RPC completion, the return Call-Future's result - value will be the response message of the RPC. - Should the event terminate with non-OK status, - the returned Call-Future's exception value will be an RpcError. - """ - - raise NotImplementedError("Method 'async_call' is not implemented.") diff --git a/dubbo/common/compression/__init__.py b/dubbo/common/compression/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/dubbo/common/compression/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dubbo/common/compression/compression.py b/dubbo/common/compression/compression.py deleted file mode 100644 index ed1569d..0000000 --- a/dubbo/common/compression/compression.py +++ /dev/null @@ -1,37 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc - - -class Compression(abc.ABC): - """Compression interface.""" - - def compress(self, data: bytes) -> bytes: - """ - Compress data. - :param data: data to be compressed. - :return: compressed data. - """ - raise NotImplementedError("Method 'compress' is not implemented.") - - def decompress(self, data: bytes) -> bytes: - """ - Decompress data. - :param data: data to be decompressed. - :return: decompressed data. - """ - raise NotImplementedError("Method 'decompress' is not implemented.") diff --git a/dubbo/common/compression/gzip.py b/dubbo/common/compression/gzip.py deleted file mode 100644 index 099fa8a..0000000 --- a/dubbo/common/compression/gzip.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gzip - -from dubbo.common.compression.compression import Compression - - -class GzipCompression(Compression): - """Gzip compression implementation.""" - - def compress(self, data: bytes) -> bytes: - """ - Compress data using gzip. - :param data: data to be compressed. - :return: compressed data. - """ - return gzip.compress(data) - - def decompress(self, data: bytes) -> bytes: - """ - Decompress data using gzip. - :param data: data to be decompressed. - :return: decompressed data. - """ - return gzip.decompress(data) diff --git a/dubbo/common/extension.py b/dubbo/common/extension.py deleted file mode 100644 index 1d4b659..0000000 --- a/dubbo/common/extension.py +++ /dev/null @@ -1,149 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -import threading -from typing import Dict, Type - -from dubbo.common.utils.file_utils import IniFileUtils - - -def load_type(config_str: str) -> Type: - """ - Dynamically load a type from a module based on a configuration string. - - :param config_str: Configuration string in the format 'module_path:class_name'. - :return: The loaded type. - :raises ValueError: If the configuration string format is incorrect or the object is not a type. - :raises ImportError: If there is an error importing the specified module. - :raises AttributeError: If the specified attribute is not found in the module. - """ - module_path, class_name = '', '' - try: - # Split the configuration string to obtain the module path and object name - module_path, class_name = config_str.rsplit('.', 1) - - # Import the module - module = importlib.import_module(module_path) - - # Get the specified type from the module - loaded_type = getattr(module, class_name) - - # Ensure the loaded object is a type (class) - if not isinstance(loaded_type, type): - raise ValueError(f"'{class_name}' is not a valid type in module '{module_path}'") - - return loaded_type - except ValueError as e: - raise ValueError("Invalid configuration string. Use 'module_path:class_name' format.") from e - except ImportError as e: - raise ImportError(f"Error importing module '{module_path}': {e}") from e - except AttributeError as e: - raise AttributeError(f"Module '{module_path}' does not have an attribute '{class_name}'") from e - - -class ExtensionLoader: - """ - Extension loader. - """ - - def __init__(self, class_type: type, classes: Dict[str, str]): - self._class_type = class_type # class type - self._classes = {} - self._instances = {} - self._instance_lock = threading.Lock() - for name, config_str in classes.items(): - o = load_type(config_str) - if issubclass(o, class_type): - self._classes[name] = o - else: - raise ValueError(f"Class {class_type} is not a subclass of {object}") - - @property - def class_type(self): - return self._class_type - - @property - def classes(self): - return self._classes - - def get_instance(self, name: str): - # check if the class exists - if name not in self._classes: - raise ValueError(f"Class {name} not found in {self._class_type}") - - # get the instance - if name not in self._instances: - with self._instance_lock: - if name not in self._instances: - self._instances[name] = self._classes[name]() - return self._instances[name] - - -class ExtensionManager: - """ - Extension manager. - """ - - def __init__(self): - self._initialized = False - self._extension_loaders: Dict[type, ExtensionLoader] = {} - - @property - def initialized(self): - return self._initialized - - def initialize(self): - """ - Read the configuration file and initialize the extension manager. 
- """ - if self._initialized: - return - # read the configuration file - extensions = IniFileUtils.parse_config("extensions.ini") - # parse the configuration - for section, classes in extensions.items(): - class_type = load_type(section) - self._extension_loaders[class_type] = ExtensionLoader(class_type, classes) - self._initialized = True - - def get_extension_loader(self, class_type: type) -> ExtensionLoader: - """ - Get the extension loader for a given class object. - - :param class_type: Class object. - :return: Extension loader. - """ - return self._extension_loaders.get(class_type) - - -# global extension manager -_EXTENSION_MANAGER = ExtensionManager() -# lock -_lock = threading.Lock() - - -def get_extension_manager() -> ExtensionManager: - """ - Get the extension manager. - - :return: Extension manager. - """ - if not _EXTENSION_MANAGER.initialized: - with _lock: - if not _EXTENSION_MANAGER.initialized: - _EXTENSION_MANAGER.initialize() - return _EXTENSION_MANAGER diff --git a/dubbo/common/config/__init__.py b/dubbo/common/extension/__init__.py similarity index 93% rename from dubbo/common/config/__init__.py rename to dubbo/common/extension/__init__.py index bcba37a..21d4970 100644 --- a/dubbo/common/config/__init__.py +++ b/dubbo/common/extension/__init__.py @@ -13,3 +13,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .logger_extension import get_logger, register_logger diff --git a/dubbo/common/node.py b/dubbo/common/extension/logger_extension.py similarity index 59% rename from dubbo/common/node.py rename to dubbo/common/extension/logger_extension.py index c75f9f3..07c337d 100644 --- a/dubbo/common/node.py +++ b/dubbo/common/extension/logger_extension.py @@ -13,30 +13,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Type -from dubbo.common.url import URL +from dubbo.logger import Logger +# A dictionary to store all the logger classes. +_logger_dict: Dict[str, Type[Logger]] = {} -class Node: + +def register_logger(name: str): """ - Node. + A decorator to register a logger class to the logger extension point. """ - def get_url(self) -> URL: - """ - Get URL. - :return: URL - """ - raise NotImplementedError("Method 'get_url' is not implemented.") - - def is_available(self) -> bool: - """ - Is available. - """ - raise NotImplementedError("Method 'is_available' is not implemented.") - - def destroy(self) -> None: - """ - Destroy - """ - raise NotImplementedError("Method 'destroy' is not implemented.") + def decorator(cls): + _logger_dict[name] = cls + return cls + + return decorator + + +def get_logger(name: str, *args, **kwargs) -> Logger: + """ + Get a logger instance by name. + """ + logger_cls = _logger_dict[name] + return logger_cls(*args, **kwargs) diff --git a/dubbo/common/url.py b/dubbo/common/url.py deleted file mode 100644 index b3c3594..0000000 --- a/dubbo/common/url.py +++ /dev/null @@ -1,92 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import urllib.parse as ulp - - -class URL: - - def __init__(self, - protocol: str, - host: str, - port: int, - username: str = None, - password: str = None, - path: str = None, - params: dict[str, str] = None - ): - """ - Initialize URL object. - :param protocol: protocols. - :param host: host. - :param port: port. - :param username: username. - :param password: password. - :param path: path. - :param params: parameters. - """ - self.protocol = protocol - self.host = host - self.port = port - self.username = username - if password and not username: - raise ValueError("Password must be set with username.") - self.password = password - self.path = path or '' - self.params = params or {} - - def to_str(self, encoded: bool = False) -> str: - """ - Convert URL object to URL string. - :param encoded: Whether to encode the URL, default is False. - """ - # Set username and password - auth_part = f"{self.username}:{self.password}@" if self.username or self.password else "" - # Set location - netloc = f"{auth_part}{self.host}{':' + str(self.port) if self.port else ''}" - query = ulp.urlencode(self.params) - path = self.path - - url_parts = (self.protocol, netloc, path, '', query, '') - url_str = str(ulp.urlunparse(url_parts)) - - if encoded: - url_str = ulp.quote(url_str) - - return url_str - - def __str__(self): - return self.to_str() - - -def parse_url(url: str, encoded: bool = False) -> URL: - """ - Parse URL string to URL object. - :param url: URL string. - :param encoded: Whether the URL is encoded, default is False. - :return: URL - """ - if encoded: - url = ulp.unquote(url) - parsed_url = ulp.urlparse(url) - protocol = parsed_url.scheme - host = parsed_url.hostname - port = parsed_url.port - path = parsed_url.path - params = {k: v[0] for k, v in ulp.parse_qs(parsed_url.query).items()} - username = parsed_url.username or '' - password = parsed_url.password or '' - return URL(protocol, host, port, username, password, path, params) diff --git a/dubbo/common/utils/__init__.py b/dubbo/common/utils/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/dubbo/common/utils/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/dubbo/common/utils/file_utils.py b/dubbo/common/utils/file_utils.py deleted file mode 100644 index ce98aca..0000000 --- a/dubbo/common/utils/file_utils.py +++ /dev/null @@ -1,57 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import configparser -from pathlib import Path -from typing import Dict - - -def get_dubbo_dir() -> Path: - """ - Get the dubbo directory. eg: /path/to/dubbo - """ - current_path = Path(__file__).resolve().parent - - for parent in current_path.parents: - if parent.name == "dubbo": - return parent - - raise FileNotFoundError("The 'dubbo' directory was not found in the path hierarchy.") - - -_CONFIG_DIR = get_dubbo_dir().parent / "config" - - -class IniFileUtils: - """ - Ini configuration file utils. - """ - - @staticmethod - def parse_config(file_name: str, file_dir: str = None, encoding: str = "utf-8") -> Dict[str, Dict[str, str]]: - """ - Parse the configuration file. - :param file_name: The name of the configuration file. - :param file_dir: The directory of the configuration file. - :param encoding: The encoding of the configuration file. - :return: The configuration. - """ - # get the file path - file_path = Path(file_dir) / file_name if file_dir else _CONFIG_DIR / file_name - # read the configuration file - cf = configparser.ConfigParser() - cf.read(file_path, encoding=encoding) - # get the configuration dict - return {section: dict(cf[section]) for section in cf.sections()} diff --git a/dubbo/config/__init__.py b/dubbo/config/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/dubbo/config/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dubbo/config/application_config.py b/dubbo/config/application_config.py deleted file mode 100644 index f7df8ad..0000000 --- a/dubbo/config/application_config.py +++ /dev/null @@ -1,64 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dubbo import logger -from dubbo.common import extension - -extension_manager = extension.get_extension_manager() - - -class ApplicationConfig: - """ - Application Config - """ - - def __init__( - self, - name: str, - version: str = '', - owner: str = '', - organization: str = '', - architecture: str = '', - environment: str = '', - logger_name: str = 'loguru'): - self._name = name - self._version = version - self._owner = owner - self._organization = organization - self._architecture = architecture - self._environment = environment - self._logger_name = logger_name - - def do_init(self): - # init logger - logger.set_logger_by_name(self.logger_name) - - @property - def logger_name(self): - return self._logger_name - - @logger_name.setter - def logger_name(self, logger_name: str): - self._logger_name = logger_name - logger.set_logger_by_name(logger_name) - - def __repr__(self): - return (f"") diff --git a/dubbo/config/config_manger.py b/dubbo/config/config_manger.py deleted file mode 100644 index 11fc536..0000000 --- a/dubbo/config/config_manger.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dubbo.config.application_config import ApplicationConfig - - -class ConfigManager: - """ - Configuration manager. - """ - # unique config in application - unique_config_types = [ - ApplicationConfig, - ] - - def __init__(self): - self._configs_cache = {} - - def add_config(self, config): - """ - Add configuration. - :param config: configuration. - """ - if type(config) not in self.unique_config_types or config.__class__ not in self._configs_cache: - self._configs_cache[type(config)] = config - else: - raise ValueError(f"Config type {type(config)} already exists.") diff --git a/dubbo/config/protocol_config.py b/dubbo/config/protocol_config.py deleted file mode 100644 index 09f09b9..0000000 --- a/dubbo/config/protocol_config.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -class ProtocolConfig: - """ - Protocol Config - """ - - def __init__(self): - # protocol name - self.name = '' - # service ip address - self.host = '' - # service port - self.port = None - # protocol codec - self.codec = '' - # serialization - self.serialization = '' - # charset - self.charset = '' - # ssl - self.ssl = False - # transporter - self.transporter = '' - # server - self.server = '' - # client - self.client = '' - # register - self.register = False diff --git a/dubbo/config/reference_config.py b/dubbo/config/reference_config.py deleted file mode 100644 index f364eda..0000000 --- a/dubbo/config/reference_config.py +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class ReferenceConfig: - """ - ReferenceConfig is the configuration of service consumer. - """ - - def __init__(self): - # A particular Protocol implementation is determined by the protocol attribute in the URL. - self.protocol = None - # A ProxyFactory implementation that will generate a reference service's proxy - self.pxy = None - # The interface of the reference service - self.method = None - # The interface proxy reference - self.ref = None - # The invoker of the reference service - self.invoker = None - # The flag whether the ReferenceConfig has been initialized - self.initialized = False diff --git a/dubbo/logger/__init__.py b/dubbo/logger/__init__.py index e4e637a..3ff3c93 100644 --- a/dubbo/logger/__init__.py +++ b/dubbo/logger/__init__.py @@ -13,5 +13,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from ._logger import Logger, set_logger, set_logger_by_name, get_logger +from ._logger import Logger diff --git a/dubbo/logger/_logger.py b/dubbo/logger/_logger.py index c0542b9..4f0a279 100644 --- a/dubbo/logger/_logger.py +++ b/dubbo/logger/_logger.py @@ -13,6 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any + class Logger: """ @@ -20,77 +22,63 @@ class Logger: All loggers should implement this interface. """ - def log(self, level: str, msg: str) -> None: + def __init__(self, name: str, *args, **kwargs): """ - Log + Initialize the logger. """ - raise NotImplementedError("Method 'log' is not implemented.") + pass - def debug(self, msg: str) -> None: + @classmethod + def get_logger(cls, name: str) -> "Logger": """ - Debug log + Get the logger by name. """ - raise NotImplementedError("Method 'debug' is not implemented.") + raise NotImplementedError("get_logger() is not implemented.") - def info(self, msg: str) -> None: + def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: """ - Info log + Log a message. """ - raise NotImplementedError("Method 'info' is not implemented.") + raise NotImplementedError("log() is not implemented.") - def warning(self, msg: str) -> None: + def debug(self, msg: str, *args, **kwargs) -> None: """ - Warning log + Log a debug message. """ - raise NotImplementedError("Method 'warning' is not implemented.") + raise NotImplementedError("debug() is not implemented.") - def error(self, msg: str) -> None: + def info(self, msg: str, *args, **kwargs) -> None: """ - Error log + Log an info message. """ - raise NotImplementedError("Method 'error' is not implemented.") + raise NotImplementedError("info() is not implemented.") - def critical(self, msg: str) -> None: + def warning(self, msg: str, *args, **kwargs) -> None: """ - Critical log + Log a warning message. """ - raise NotImplementedError("Method 'critical' is not implemented.") + raise NotImplementedError("warning() is not implemented.") - def exception(self, msg: str) -> None: + def error(self, msg: str, *args, **kwargs) -> None: """ - Exception log + Log an error message. """ - raise NotImplementedError("Method 'exception' is not implemented.") - - -# global logger, default logger is Logger(), so it will raise an error if it is not set -_LOGGER: Logger = Logger() + raise NotImplementedError("error() is not implemented.") + def critical(self, msg: str, *args, **kwargs) -> None: + """ + Log a critical message. + """ + raise NotImplementedError("critical() is not implemented.") -def set_logger(logger: Logger) -> None: - """ - Set global logger - """ - global _LOGGER - if logger is not None and isinstance(logger, Logger): - _LOGGER = logger - else: - raise ValueError("Invalid logger") - - -def set_logger_by_name(logger_name: str) -> None: - """ - Set global logger by name - """ - # import extension module here to avoid circular import - from dubbo.common import extension - extension_manager = extension.get_extension_manager() - instance = extension_manager.get_extension_loader(Logger).get_instance(logger_name) - set_logger(instance) - + def fatal(self, msg: str, *args, **kwargs) -> None: + """ + Log a fatal message. + """ + raise NotImplementedError("fatal() is not implemented.") -def get_logger() -> Logger: - """ - Get global logger - """ - return _LOGGER + def exception(self, msg: str, *args, **kwargs) -> None: + """ + Log an exception message. + """ + raise NotImplementedError("exception() is not implemented.") diff --git a/dubbo/logger/internal_logger.py b/dubbo/logger/internal_logger.py new file mode 100644 index 0000000..9c9f8a4 --- /dev/null +++ b/dubbo/logger/internal_logger.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Any, Dict + +from dubbo.common import extension +from dubbo.logger import Logger + + +@extension.register_logger(name="internal") +class InternalLogger(Logger): + + _loggers: Dict[str, "InternalLogger"] = {} + + def __init__(self, name: str, *args, **kwargs): + super().__init__(name, *args, **kwargs) + self._logger = logging.getLogger(name) + # Set the default log format. + handler = logging.StreamHandler() + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(module)s:%(funcName)s:%(lineno)d - [Dubbo] %(message)s" + ) + handler.setFormatter(formatter) + self._logger.addHandler(handler) + + @classmethod + def get_logger(cls, name: str) -> "Logger": + logger_instance = cls._loggers.get(name, None) + if logger_instance is None: + logger_instance = cls(name) + cls._loggers[name] = logger_instance + return logger_instance + + def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: + self._logger.log(level, msg, *args, **kwargs) + + def debug(self, msg: str, *args, **kwargs) -> None: + self._logger.debug(msg, *args, **kwargs) + + def info(self, msg: str, *args, **kwargs) -> None: + self._logger.info(msg, *args, **kwargs) + + def warning(self, msg: str, *args, **kwargs) -> None: + self._logger.warning(msg, *args, **kwargs) + + def error(self, msg: str, *args, **kwargs) -> None: + self._logger.error(msg, *args, **kwargs) + + def critical(self, msg: str, *args, **kwargs) -> None: + self._logger.critical(msg, *args, **kwargs) + + def fatal(self, msg: str, *args, **kwargs) -> None: + self._logger.fatal(msg, *args, **kwargs) + + def exception(self, msg: str, *args, **kwargs) -> None: + self._logger.exception(msg, *args, **kwargs) diff --git a/dubbo/logger/loguru_logger.py b/dubbo/logger/loguru_logger.py deleted file mode 100644 index 12e62c2..0000000 --- a/dubbo/logger/loguru_logger.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from loguru import logger - -from dubbo.logger import Logger - - -class LoguruLogger(Logger): - """ - Loguru logger. 
- """ - - def __init__(self): - self.logger = logger.opt(depth=1) - - def log(self, level: str, msg: str) -> None: - self.logger.log(level, msg) - - def debug(self, msg: str) -> None: - self.logger.debug(msg) - - def info(self, msg: str) -> None: - self.logger.info(msg) - - def warning(self, msg: str) -> None: - self.logger.warning(msg) - - def error(self, msg: str) -> None: - self.logger.error(msg) - - def critical(self, msg: str) -> None: - self.logger.critical(msg) - - def exception(self, msg: str) -> None: - self.logger.exception(msg) diff --git a/dubbo/protocols/__init__.py b/dubbo/protocols/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/dubbo/protocols/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dubbo/protocols/invoker.py b/dubbo/protocols/invoker.py deleted file mode 100644 index 14c9f29..0000000 --- a/dubbo/protocols/invoker.py +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dubbo.common.node import Node - - -class Invoker(Node): - """ - Invoker. - """ - - def get_interface(self): - """ - Get service interface. - """ - raise NotImplementedError("Method 'get_interface' is not implemented.") - - def invoke(self): - """ - Invoke. - """ - raise NotImplementedError("Method 'invoke' is not implemented.") diff --git a/dubbo/protocols/protocol.py b/dubbo/protocols/protocol.py deleted file mode 100644 index a6df8da..0000000 --- a/dubbo/protocols/protocol.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dubbo.common.url import URL -from dubbo.protocols.invoker import Invoker - - -class Protocol: - """ - RPC Protocol extension interface, which encapsulates the details of remote invocation. - """ - - def export(self, invoker: Invoker): - """ - Export service for remote invocation - :param invoker: service invoker - """ - raise NotImplementedError("Method 'export' is not implemented.") - - def refer(self, service_type, url: URL): - """ - Refer a remote service. - :param service_type: service class - :param url: URL address for the remote service - """ - raise NotImplementedError("Method 'refer' is not implemented.") diff --git a/dubbo/protocols/triple/__init__.py b/dubbo/protocols/triple/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/dubbo/protocols/triple/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dubbo/protocols/triple/triple_protocol.py b/dubbo/protocols/triple/triple_protocol.py deleted file mode 100644 index 85357b8..0000000 --- a/dubbo/protocols/triple/triple_protocol.py +++ /dev/null @@ -1,29 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dubbo.protocols.protocol import Protocol - - -class TripleProtocol(Protocol): - """ - Triple protocols. 
- """ - - def export(self, invoker): - raise NotImplementedError('export method is not implemented') - - def refer(self, service_type, url): - raise NotImplementedError('refer method is not implemented') diff --git a/dubbo/protocols/invocation.py b/dubbo/run.py similarity index 92% rename from dubbo/protocols/invocation.py rename to dubbo/run.py index 54a1481..5da4bd6 100644 --- a/dubbo/protocols/invocation.py +++ b/dubbo/run.py @@ -14,5 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -class Invocation: + +class Dubbo: + """The entry point of dubbo-python framework.""" + pass diff --git a/requirements.txt b/requirements.txt index a38bb99..e69de29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +0,0 @@ -loguru~=0.7.2 \ No newline at end of file diff --git a/tests/common/__init__.py b/tests/common/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/tests/common/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/common/test_extension.py b/tests/common/test_extension.py deleted file mode 100644 index 63fc929..0000000 --- a/tests/common/test_extension.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from dubbo.common import extension -from dubbo import logger - - -class TestExtension(unittest.TestCase): - - def test_get_instance(self): - manager = extension.get_extension_manager() - assert manager is not None - loader = manager.get_extension_loader(logger.Logger) - assert loader is not None - dubbo_logger = loader.get_instance("loguru") - assert dubbo_logger is not None diff --git a/tests/common/test_url.py b/tests/common/test_url.py deleted file mode 100644 index 09ac1ef..0000000 --- a/tests/common/test_url.py +++ /dev/null @@ -1,78 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from dubbo.common import url as dubbo_url - - -class TestURL(unittest.TestCase): - - def test_parse_url_with_params(self): - url = "registry://192.168.1.7:9090/org.apache.dubbo.service1?param1=value1&param2=value2" - parsed = dubbo_url.parse_url(url) - self.assertEqual(parsed.protocol, "registry") - self.assertEqual(parsed.host, "192.168.1.7") - self.assertEqual(parsed.port, 9090) - self.assertEqual(parsed.path, "/org.apache.dubbo.service1") - self.assertEqual(parsed.params, {"param1": "value1", "param2": "value2"}) - self.assertEqual(parsed.username, "") - self.assertEqual(parsed.password, "") - self.assertEqual(parsed.to_str(), url) - - def test_parse_url_with_auth(self): - url = "http://username:password@10.20.130.230:8080/list?version=1.0.0" - parsed = dubbo_url.parse_url(url) - self.assertEqual(parsed.protocol, "http") - self.assertEqual(parsed.host, "10.20.130.230") - self.assertEqual(parsed.port, 8080) - self.assertEqual(parsed.path, "/list") - self.assertEqual(parsed.params, {"version": "1.0.0"}) - self.assertEqual(parsed.username, "username") - self.assertEqual(parsed.password, "password") - self.assertEqual(parsed.to_str(), url) - - def test_to_str_with_encoded(self): - url = "http://username:password@10.20.130.230:8080/list?version=1.0.0" - parsed = dubbo_url.parse_url(url) - encoded_url = parsed.to_str(encoded=True) - self.assertNotEqual(encoded_url, url) - self.assertTrue('%3F' in encoded_url) - - def test_to_str_without_params(self): - url = "http://www.example.com" - parsed = dubbo_url.parse_url(url) - self.assertEqual(parsed.protocol, "http") - self.assertEqual(parsed.host, "www.example.com") - self.assertEqual(parsed.path, "") - self.assertEqual(parsed.params, {}) - self.assertEqual(parsed.username, "") - self.assertEqual(parsed.password, "") - self.assertEqual(parsed.to_str(), "http://www.example.com") - - def test_parse_url_encoded(self): - encoded_url = "http%3A%2F%2Fwww.facebook.com%2Ffriends%3Fparam1%3Dvalue1%26param2%3Dvalue2" - parsed = dubbo_url.parse_url(encoded_url, encoded=True) - self.assertEqual(parsed.protocol, "http") - self.assertEqual(parsed.host, "www.facebook.com") - self.assertEqual(parsed.path, "/friends") - self.assertEqual(parsed.params, {"param1": "value1", "param2": "value2"}) - self.assertEqual(parsed.username, "") - self.assertEqual(parsed.password, "") - self.assertEqual(parsed.to_str(), "http://www.facebook.com/friends?param1=value1&param2=value2") - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/config/__init__.py b/tests/config/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/tests/config/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements.
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/config/test_application_config.py b/tests/config/test_application_config.py deleted file mode 100644 index 3c49553..0000000 --- a/tests/config/test_application_config.py +++ /dev/null @@ -1,32 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest - -from dubbo.config.application_config import ApplicationConfig -from dubbo import logger - - -class TestApplicationConfig(unittest.TestCase): - - def test_init_logger(self): - config = ApplicationConfig(name='dubbo') - config.do_init() - dubbo_logger = logger.get_logger() - dubbo_logger.debug('debug') - dubbo_logger.info('info') - dubbo_logger.warning('warning') - dubbo_logger.error('error') - assert True diff --git a/tests/logger/test_loguru_logger.py b/tests/logger/test_internal_logger.py similarity index 65% rename from tests/logger/test_loguru_logger.py rename to tests/logger/test_internal_logger.py index 849fc58..5a3167a 100644 --- a/tests/logger/test_loguru_logger.py +++ b/tests/logger/test_internal_logger.py @@ -13,23 +13,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import unittest -from dubbo.logger.loguru_logger import LoguruLogger +from dubbo.logger.internal_logger import InternalLogger -class TestLoguruLogger(unittest.TestCase): +class TestInternalLogger(unittest.TestCase): - def test_loguru_logger(self): - logger = LoguruLogger() - logger.debug("Debug log") - logger.info("Info log") - logger.warning("Warning log") - logger.error("Error log") - logger.critical("Critical log") + def test_log(self): + logger = InternalLogger.get_logger("test") + logger.log(10, "test log") + logger.debug("test debug") + logger.info("test info") + logger.warning("test warning") + logger.error("test error") try: - return 1 / 0 + 1 / 0 except ZeroDivisionError: - logger.exception("exception!!!") - assert True + logger.exception("test exception") + self.assertTrue(True) From 8f0a1556ee189b5a970fd5f5213e7b1f23bd05da Mon Sep 17 00:00:00 2001 From: zaki Date: Thu, 13 Jun 2024 22:33:08 +0800 Subject: [PATCH 14/32] feat: add logger extension --- .flake8 | 9 +- dubbo/__init__.py | 1 + dubbo/{run.py => _dubbo.py} | 0 dubbo/common/extension/__init__.py | 2 +- dubbo/common/extension/logger_extension.py | 46 +++- dubbo/imports.py | 19 ++ dubbo/logger/__init__.py | 2 +- dubbo/logger/_logger.py | 203 ++++++++++++++++-- dubbo/logger/internal_logger.py | 126 ++++++++--- tests/common/__init__.py | 15 ++ tests/common/extension/__init__.py | 15 ++ .../common/extension/test_logger_extension.py | 33 +++ tests/logger/test_internal_logger.py | 23 +- tests/test_dubbo.py | 24 +++ 14 files changed, 452 insertions(+), 66 deletions(-) rename dubbo/{run.py => _dubbo.py} (100%) create mode 100644 dubbo/imports.py create mode 100644 tests/common/__init__.py create mode 100644 tests/common/extension/__init__.py create mode 100644 tests/common/extension/test_logger_extension.py create mode 100644 tests/test_dubbo.py diff --git a/.flake8 b/.flake8 index 6aa0376..f5b3b3c 100644 --- a/.flake8 +++ b/.flake8 @@ -19,9 +19,12 @@ exclude = .idea, .git, __pycache__, - docs + docs, + tests per-file-ignores = __init__.py:F401 - dubbo/imports/imports.py:F401 - dubbo/pydubbo.py:F401 + # module level import not at top of file + dubbo/imports.py:F401 + # module level import not at top of file + dubbo/common/extension/logger_extension.py:E402 diff --git a/dubbo/__init__.py b/dubbo/__init__.py index bcba37a..87db198 100644 --- a/dubbo/__init__.py +++ b/dubbo/__init__.py @@ -13,3 +13,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ._dubbo import Dubbo diff --git a/dubbo/run.py b/dubbo/_dubbo.py similarity index 100% rename from dubbo/run.py rename to dubbo/_dubbo.py diff --git a/dubbo/common/extension/__init__.py b/dubbo/common/extension/__init__.py index 21d4970..c3ee8fe 100644 --- a/dubbo/common/extension/__init__.py +++ b/dubbo/common/extension/__init__.py @@ -13,4 +13,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from .logger_extension import get_logger, register_logger +from .logger_extension import get_logger_adapter, register_logger_adapter diff --git a/dubbo/common/extension/logger_extension.py b/dubbo/common/extension/logger_extension.py index 07c337d..998f029 100644 --- a/dubbo/common/extension/logger_extension.py +++ b/dubbo/common/extension/logger_extension.py @@ -13,29 +13,55 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Type -from dubbo.logger import Logger +""" +This module provides an extension point for logger adapters. +Note: Type annotations are not fully used here (LoggerAdapter object is not explicitly specified) +because it would cause a circular reference issue. +""" -# A dictionary to store all the logger classes. -_logger_dict: Dict[str, Type[Logger]] = {} +# A dictionary to store all the logger adapters. key: name, value: logger adapter class +_logger_adapter_dict = {} -def register_logger(name: str): +def register_logger_adapter(name: str): """ A decorator to register a logger class to the logger extension point. + + This function returns a decorator that registers the decorated class + as a logger adapter under the specified name. + + Args: + name (str): The name to register the logger adapter under. + + Returns: + Callable[[Type[LoggerAdapter]], Type[LoggerAdapter]]: + A decorator function that registers the logger class. """ def decorator(cls): - _logger_dict[name] = cls + _logger_adapter_dict[name] = cls return cls return decorator -def get_logger(name: str, *args, **kwargs) -> Logger: +def get_logger_adapter(name: str, *args, **kwargs): """ - Get a logger instance by name. + Get a logger adapter instance by name. + + This function retrieves a logger adapter class by its registered name and + instantiates it with the provided arguments. + + Args: + name (str): The name of the logger adapter to retrieve. + *args: Variable length argument list for the logger adapter constructor. + **kwargs: Arbitrary keyword arguments for the logger adapter constructor. + + Returns: + LoggerAdapter: An instance of the requested logger adapter. + Raises: + KeyError: If no logger adapter is registered under the provided name. """ - logger_cls = _logger_dict[name] - return logger_cls(*args, **kwargs) + logger_adapter = _logger_adapter_dict[name] + return logger_adapter(*args, **kwargs) diff --git a/dubbo/imports.py b/dubbo/imports.py new file mode 100644 index 0000000..1e860c9 --- /dev/null +++ b/dubbo/imports.py @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilizing the mechanism of module loading to complete the registration of plugins.""" + +import dubbo.logger.internal_logger diff --git a/dubbo/logger/__init__.py b/dubbo/logger/__init__.py index 3ff3c93..de344ef 100644 --- a/dubbo/logger/__init__.py +++ b/dubbo/logger/__init__.py @@ -13,4 +13,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._logger import Logger +from ._logger import Level, Logger, LoggerAdapter, LoggerFactory diff --git a/dubbo/logger/_logger.py b/dubbo/logger/_logger.py index 4f0a279..865fb73 100644 --- a/dubbo/logger/_logger.py +++ b/dubbo/logger/_logger.py @@ -13,72 +13,243 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +import enum +import threading +from typing import Any, Dict + +from dubbo.common import extension + + +@enum.unique +class Level(enum.Enum): + """ + The logging level enum. + """ + + DEBUG = "DEBUG" + INFO = "INFO" + WARNING = "WARNING" + ERROR = "ERROR" + CRITICAL = "CRITICAL" + FATAL = "FATAL" class Logger: """ Logger Interface, which is used to log messages. - All loggers should implement this interface. """ - def __init__(self, name: str, *args, **kwargs): - """ - Initialize the logger. - """ - pass - - @classmethod - def get_logger(cls, name: str) -> "Logger": - """ - Get the logger by name. + def log(self, level: Level, msg: str, *args: Any, **kwargs: Any) -> None: """ - raise NotImplementedError("get_logger() is not implemented.") + Log a message at the specified logging level. - def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: - """ - Log a message. + Args: + level (Level): The logging level. + msg (str): The log message. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. """ raise NotImplementedError("log() is not implemented.") def debug(self, msg: str, *args, **kwargs) -> None: """ Log a debug message. + + Args: + msg (str): The debug message. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. """ raise NotImplementedError("debug() is not implemented.") def info(self, msg: str, *args, **kwargs) -> None: """ Log an info message. + + Args: + msg (str): The info message. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. """ raise NotImplementedError("info() is not implemented.") def warning(self, msg: str, *args, **kwargs) -> None: """ Log a warning message. + + Args: + msg (str): The warning message. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. """ raise NotImplementedError("warning() is not implemented.") def error(self, msg: str, *args, **kwargs) -> None: """ Log an error message. + + Args: + msg (str): The error message. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. """ raise NotImplementedError("error() is not implemented.") def critical(self, msg: str, *args, **kwargs) -> None: """ Log a critical message. + + Args: + msg (str): The critical message. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. """ raise NotImplementedError("critical() is not implemented.") def fatal(self, msg: str, *args, **kwargs) -> None: """ Log a fatal message. 
+ + Args: + msg (str): The fatal message. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. """ raise NotImplementedError("fatal() is not implemented.") def exception(self, msg: str, *args, **kwargs) -> None: """ Log an exception message. + + Args: + msg (str): The exception message. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. """ raise NotImplementedError("exception() is not implemented.") + + +class LoggerAdapter: + """ + Logger Adapter Interface, which is used to support different logging libraries. + """ + + def __init__(self, *args, **kwargs): + """ + Initialize the logger adapter. + + Args: + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. + """ + pass + + def get_logger(self, name: str) -> Logger: + """ + Get a logger by name. + + Args: + name (str): The name of the logger. + + Returns: + Logger: An instance of the logger. + """ + raise NotImplementedError("get_logger() is not implemented.") + + @property + def level(self) -> Level: + """ + Get the current logging level. + + Returns: + Level: The current logging level. + """ + raise NotImplementedError("get_level() is not implemented.") + + @level.setter + def level(self, level: Level) -> None: + """ + Set the logging level. + + Args: + level (Level): The logging level to set. + """ + raise NotImplementedError("set_level() is not implemented.") + + +class LoggerFactory: + """ + Factory class to create loggers. + """ + + # The logger adapter. + _logger_adapter: LoggerAdapter + + # A dictionary to store all the loggers. + _loggers: Dict[str, Logger] = {} + + # A lock to protect the loggers. + _logger_lock = threading.Lock() + + @classmethod + def get_logger_adapter(cls) -> LoggerAdapter: + """ + Get the logger adapter. + + Returns: + LoggerAdapter: The current logger adapter. + """ + return cls._logger_adapter + + @classmethod + def set_logger_adapter(cls, logger_adapter: str) -> None: + """ + Set the logger adapter. + + Args: + logger_adapter (str): The name of the logger adapter to set. + """ + cls._logger_adapter = extension.get_logger_adapter(logger_adapter) + # update all loggers + cls._loggers = { + name: cls._logger_adapter.get_logger(name) for name in cls._loggers + } + + @classmethod + def get_logger(cls, name: str) -> Logger: + """ + Get the logger by name. + + Args: + name (str): The name of the logger to retrieve. + + Returns: + Logger: An instance of the requested logger. + """ + logger = cls._loggers.get(name) + if logger is None: + with cls._logger_lock: + if name not in cls._loggers: + cls._loggers[name] = cls._logger_adapter.get_logger(name) + logger = cls._loggers[name] + return logger + + @classmethod + def set_level(cls, level: Level) -> None: + """ + Set the logging level. + + Args: + level (Level): The logging level to set. + """ + cls._logger_adapter.level = level + + @classmethod + def get_level(cls) -> Level: + """ + Get the current logging level. + + Returns: + Level: The current logging level. + """ + return cls._logger_adapter.level diff --git a/dubbo/logger/internal_logger.py b/dubbo/logger/internal_logger.py index 9c9f8a4..5aa6c87 100644 --- a/dubbo/logger/internal_logger.py +++ b/dubbo/logger/internal_logger.py @@ -13,57 +13,121 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import logging -from typing import Any, Dict +from typing import Dict from dubbo.common import extension -from dubbo.logger import Logger +from dubbo.logger import Level, Logger, LoggerAdapter + +"""This module provides the internal logger implementation. -> logging module""" + +# The mapping from the logging level to the internal logging level. +_level_map: Dict[Level, int] = { + Level.DEBUG: logging.DEBUG, + Level.INFO: logging.INFO, + Level.WARNING: logging.WARNING, + Level.ERROR: logging.ERROR, + Level.CRITICAL: logging.CRITICAL, + Level.FATAL: logging.FATAL, +} -@extension.register_logger(name="internal") class InternalLogger(Logger): + """ + The internal logger implementation. + """ - _loggers: Dict[str, "InternalLogger"] = {} + def __init__(self, internal_logger: logging.Logger): + self._logger = internal_logger - def __init__(self, name: str, *args, **kwargs): - super().__init__(name, *args, **kwargs) - self._logger = logging.getLogger(name) - # Set the default log format. - handler = logging.StreamHandler() - formatter = logging.Formatter( - fmt="%(asctime)s | %(levelname)s | %(module)s:%(funcName)s:%(lineno)d - [Dubbo] %(message)s" - ) - handler.setFormatter(formatter) - self._logger.addHandler(handler) - - @classmethod - def get_logger(cls, name: str) -> "Logger": - logger_instance = cls._loggers.get(name, None) - if logger_instance is None: - logger_instance = cls(name) - cls._loggers[name] = logger_instance - return logger_instance - - def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: + def _log(self, level: int, msg: str, *args, **kwargs) -> None: + # Add the stacklevel to the keyword arguments. + kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 2 self._logger.log(level, msg, *args, **kwargs) + def log(self, level: Level, msg: str, *args, **kwargs) -> None: + self._log(_level_map[level], msg, *args, **kwargs) + def debug(self, msg: str, *args, **kwargs) -> None: - self._logger.debug(msg, *args, **kwargs) + self._log(logging.DEBUG, msg, *args, **kwargs) def info(self, msg: str, *args, **kwargs) -> None: - self._logger.info(msg, *args, **kwargs) + self._log(logging.INFO, msg, *args, **kwargs) def warning(self, msg: str, *args, **kwargs) -> None: - self._logger.warning(msg, *args, **kwargs) + self._log(logging.WARNING, msg, *args, **kwargs) def error(self, msg: str, *args, **kwargs) -> None: - self._logger.error(msg, *args, **kwargs) + self._log(logging.ERROR, msg, *args, **kwargs) def critical(self, msg: str, *args, **kwargs) -> None: - self._logger.critical(msg, *args, **kwargs) + self._log(logging.CRITICAL, msg, *args, **kwargs) def fatal(self, msg: str, *args, **kwargs) -> None: - self._logger.fatal(msg, *args, **kwargs) + self._log(logging.FATAL, msg, *args, **kwargs) def exception(self, msg: str, *args, **kwargs) -> None: - self._logger.exception(msg, *args, **kwargs) + if kwargs.get("exc_info") is None: + kwargs["exc_info"] = True + self.error(msg, *args, **kwargs) + + +@extension.register_logger_adapter("internal") +class InternalLoggerAdapter(LoggerAdapter): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Set the default logging level to DEBUG. + self._level = Level.DEBUG + self._update_level(Level.DEBUG) + + def get_logger(self, name: str) -> Logger: + """ + Create a logger instance by name. + Args: + name (str): The logger name. + Returns: + Logger: The InternalLogger instance. 
+ """ + # TODO enable config by args + logger_instance = logging.getLogger(name) + # Create a formatter + formatter = logging.Formatter( + "%(asctime)s | %(levelname)s | %(module)s:%(funcName)s:%(lineno)d - [Dubbo] %(message)s" + ) + # Add a console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger_instance.addHandler(console_handler) + return InternalLogger(logger_instance) + + @property + def level(self) -> Level: + """ + Get the logging level. + Returns: + Level: The logging level. + """ + return self._level + + @level.setter + def level(self, level: Level) -> None: + """ + Set the logging level. + Args: + level (Level): The logging level. + """ + if level == self._level or level is None: + return + self._level = level + self._update_level(level) + + def _update_level(self, level: Level) -> None: + """ + Update the logging level. + """ + # Get the root logger + root_logger = logging.getLogger() + # Set the logging level + root_logger.setLevel(level.name) diff --git a/tests/common/__init__.py b/tests/common/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/tests/common/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/common/extension/__init__.py b/tests/common/extension/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/tests/common/extension/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/common/extension/test_logger_extension.py b/tests/common/extension/test_logger_extension.py new file mode 100644 index 0000000..96a50c0 --- /dev/null +++ b/tests/common/extension/test_logger_extension.py @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + + +class TestLoggerExtension(unittest.TestCase): + + def test_logger_extension(self): + import dubbo.imports + from dubbo.common import extension + + # Test the get_logger_adapter method. + logger_adapter = extension.get_logger_adapter("internal") + + # Test logger_adapter methods. + logger = logger_adapter.get_logger("test") + logger.debug("test debug") + logger.info("test info") + logger.warning("test warning") + logger.error("test error") \ No newline at end of file diff --git a/tests/logger/test_internal_logger.py b/tests/logger/test_internal_logger.py index 5a3167a..3f32a36 100644 --- a/tests/logger/test_internal_logger.py +++ b/tests/logger/test_internal_logger.py @@ -15,20 +15,35 @@ # limitations under the License. import unittest -from dubbo.logger.internal_logger import InternalLogger +from dubbo.logger import Level +from dubbo.logger.internal_logger import InternalLoggerAdapter class TestInternalLogger(unittest.TestCase): def test_log(self): - logger = InternalLogger.get_logger("test") - logger.log(10, "test log") + logger_adapter = InternalLoggerAdapter() + logger = logger_adapter.get_logger("test") + logger.log(Level.INFO, "test log") logger.debug("test debug") logger.info("test info") logger.warning("test warning") logger.error("test error") + logger.critical("test critical") + logger.fatal("test fatal") try: 1 / 0 except ZeroDivisionError: logger.exception("test exception") - self.assertTrue(True) + + # test different default logger level + logger_adapter.level = Level.INFO + logger.debug("debug can't be logged") + + logger_adapter.level = Level.WARNING + logger.info("info can't be logged") + + logger_adapter.level = Level.ERROR + logger.warning("warning can't be logged") + + diff --git a/tests/test_dubbo.py b/tests/test_dubbo.py new file mode 100644 index 0000000..a9cdebd --- /dev/null +++ b/tests/test_dubbo.py @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
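The new tests follow the standard unittest layout, so they should be runnable with the stock runner, for example (the exact command used by CI is not shown in this series):

    python -m unittest discover -s tests -p "test_*.py" -v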
+import unittest + + +class TestDubbo(unittest.TestCase): + + def test_dubbo(self): + from dubbo import Dubbo + + Dubbo() From 2def13b97561c8a797d9ea1aa742f6a16ef2da09 Mon Sep 17 00:00:00 2001 From: zaki Date: Fri, 14 Jun 2024 18:30:51 +0800 Subject: [PATCH 15/32] feat: add url --- dubbo/common/url.py | 328 ++++++++++++++++++++++++++++++++ dubbo/config/__init__.py | 15 ++ dubbo/config/logger_config.py | 87 +++++++++ dubbo/logger/__init__.py | 2 +- dubbo/logger/_logger.py | 14 ++ dubbo/logger/internal_logger.py | 21 +- tests/common/tets_url.py | 78 ++++++++ 7 files changed, 541 insertions(+), 4 deletions(-) create mode 100644 dubbo/common/url.py create mode 100644 dubbo/config/__init__.py create mode 100644 dubbo/config/logger_config.py create mode 100644 tests/common/tets_url.py diff --git a/dubbo/common/url.py b/dubbo/common/url.py new file mode 100644 index 0000000..739a3a7 --- /dev/null +++ b/dubbo/common/url.py @@ -0,0 +1,328 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional +from urllib import parse + + +class URL: + """ + URL - Uniform Resource Locator + + url example: + - http://www.facebook.com/friends?param1=value1¶m2=value2 + - http://username:password@10.20.130.230:8080/list?version=1.0.0 + - ftp://username:password@192.168.1.7:21/1/read.txt + - registry://192.168.1.7:9090/org.apache.dubbo.service1?param1=value1¶m2=value2 + """ + + def __init__( + self, + protocol: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + username: Optional[str] = None, + password: Optional[str] = None, + path: Optional[str] = None, + params: Optional[Dict[str, str]] = None, + ): + """ + Initializes the URL with the given components. + + Args: + protocol (Optional[str]): The protocol of the URL. + host (Optional[str]): The host of the URL. + port (Optional[int]): The port number of the URL. + username (Optional[str]): The username for URL authentication. + password (Optional[str]): The password for URL authentication. + path (Optional[str]): The path of the URL. + params (Optional[Dict[str, str]]): The query parameters of the URL. + """ + self._protocol = protocol + self._host = host + self._port = port + # address = host:port + self._address = None if not host else f"{host}:{port}" if port else host + self._username = username + self._password = password + self._path = path + self._params = params + + @property + def protocol(self) -> Optional[str]: + """ + Gets the protocol of the URL. + + Returns: + Optional[str]: The protocol of the URL. + """ + return self._protocol + + @protocol.setter + def protocol(self, protocol: str) -> None: + """ + Sets the protocol of the URL. + + Args: + protocol (str): The protocol to set. 
+ """ + self._protocol = protocol + + @property + def address(self) -> Optional[str]: + """ + Gets the address (host:port) of the URL. + + Returns: + Optional[str]: The address of the URL. + """ + return self._address + + @address.setter + def address(self, address: str) -> None: + """ + Sets the address (host:port) of the URL. + + Args: + address (str): The address to set. + """ + self._address = address + if ":" in address: + self._host, port = address.split(":") + self._port = int(port) + else: + self._host = address + self._port = None + + @property + def host(self) -> Optional[str]: + """ + Gets the host of the URL. + + Returns: + Optional[str]: The host of the URL. + """ + return self._host + + @host.setter + def host(self, host: str) -> None: + """ + Sets the host of the URL. + + Args: + host (str): The host to set. + """ + self._host = host + self._address = f"{host}:{self.port}" if self.port else host + + @property + def port(self) -> Optional[int]: + """ + Gets the port of the URL. + + Returns: + Optional[int]: The port of the URL. + """ + return self._port + + @port.setter + def port(self, port: int) -> None: + """ + Sets the port of the URL. + + Args: + port (int): The port to set. + """ + self._port = port + self._address = f"{self.host}:{port}" if port else self.host + + @property + def username(self) -> Optional[str]: + """ + Gets the username for URL authentication. + + Returns: + Optional[str]: The username for URL authentication. + """ + return self._username + + @username.setter + def username(self, username: str) -> None: + """ + Sets the username for URL authentication. + + Args: + username (str): The username to set. + """ + self._username = username + + @property + def password(self) -> Optional[str]: + """ + Gets the password for URL authentication. + + Returns: + Optional[str]: The password for URL authentication. + """ + return self._password + + @password.setter + def password(self, password: str) -> None: + """ + Sets the password for URL authentication. + + Args: + password (str): The password to set. + """ + self._password = password + + @property + def path(self) -> Optional[str]: + """ + Gets the path of the URL. + + Returns: + Optional[str]: The path of the URL. + """ + return self._path + + @path.setter + def path(self, path: str) -> None: + """ + Sets the path of the URL. + + Args: + path (str): The path to set. + """ + self._path = path + + @property + def params(self) -> Optional[Dict[str, str]]: + """ + Gets the query parameters of the URL. + + Returns: + Optional[Dict[str, str]]: The query parameters of the URL. + """ + return self._params + + @params.setter + def params(self, params: Dict[str, str]) -> None: + """ + Sets the query parameters of the URL. + + Args: + params (Dict[str, str]): The query parameters to set. + """ + self._params = params + + def get_param(self, key: str) -> Optional[str]: + """ + Gets a query parameter from the URL. + + Args: + key (str): The parameter name. + + Returns: + str or None: The parameter value. If the parameter does not exist, returns None. + """ + return self._params.get(key, None) if self._params else None + + def add_param(self, key: str, value: str) -> None: + """ + Adds a query parameter to the URL. + + Args: + key (str): The parameter name. + value (str): The parameter value. + """ + if not self._params: + self._params = {} + self._params[key] = value + + def to_string(self, encode: bool = False) -> str: + """ + Generates the URL string based on the current components. 
+ + Args: + encode (bool): If True, the URL will be percent-encoded. + + Returns: + str: The generated URL string. + """ + # Set protocol + url = f"{self.protocol}://" if self.protocol else "" + # Set auth + if self.username: + url += f"{self.username}" + if self.password: + url += f":{self.password}" + url += "@" + # Set Address + url += self.address if self.address else "" + # Set path + url += "/" + if self.path: + url += f"{self.path}" + # Set params + if self.params: + url += "?" + "&".join([f"{k}={v}" for k, v in self.params.items()]) + # If the URL needs to be encoded, encode it + if encode: + url = parse.quote(url) + return url + + def __str__(self) -> str: + """ + Returns the URL string when the object is converted to a string. + + Returns: + str: The generated URL string. + """ + return self.to_string() + + @classmethod + def value_of(cls, url: str, encoded: bool = False) -> "URL": + """ + Creates a URL object from a URL string. + + Args: + url (str): The URL string to parse. format: [protocol://][username:password@][host:port]/[path] + encoded (bool): If True, the URL string is percent-encoded and will be decoded. + + Returns: + URL: The created URL object. + """ + if not url: + raise ValueError() + + # If the URL is encoded, decode it + if encoded: + url = parse.unquote(url) + + if "://" not in url: + raise ValueError("Invalid URL format: missing protocol") + + parsed_url = parse.urlparse(url) + + protocol = parsed_url.scheme + host = parsed_url.hostname + port = parsed_url.port + username = parsed_url.username + password = parsed_url.password + params = {k: v[0] for k, v in parse.parse_qs(parsed_url.query).items()} + path = parsed_url.path.lstrip("/") + + return URL(protocol, host, port, username, password, path, params) diff --git a/dubbo/config/__init__.py b/dubbo/config/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/config/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/config/logger_config.py b/dubbo/config/logger_config.py new file mode 100644 index 0000000..6ea97f8 --- /dev/null +++ b/dubbo/config/logger_config.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
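A round-trip sketch of the URL helper defined above (the address and parameters are illustrative):

    from dubbo.common.url import URL

    url = URL.value_of("tri://user:pwd@127.0.0.1:50051/org.demo.GreeterService?timeout=3000")
    assert url.protocol == "tri" and url.port == 50051
    assert url.get_param("timeout") == "3000"
    url.add_param("serialization", "protobuf")
    # tri://user:pwd@127.0.0.1:50051/org.demo.GreeterService?timeout=3000&serialization=protobuf
    print(url.to_string())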
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from dataclasses import dataclass +from typing import Optional + +from dubbo.logger import Level, RotateType + + +@dataclass +class ConsoleLoggerConfig: + # default is open console logger + enabled: bool = True + # default level is None, use the global level + level: Optional[Level] = None + # default formatter is None, use the global formatter + formatter: Optional[str] = None + + +@dataclass +class FileLoggerConfig: + # default is close file logger + enabled: bool = False + # default level is None, use the global level + level: Optional[Level] = None + # default formatter is None, use the global formatter + formatter: Optional[str] = None + # default log file dir is user home dir + file_dir: Optional[str] = os.path.expanduser("~") + # default no rotate + rotate: Optional[RotateType] = RotateType.NONE + # when rotate is SIZE, max_bytes is required, default 10M + max_bytes: Optional[int] = 1024 * 1024 * 10 + # when rotate is TIME, rotation is required, unit is day, default 1 + rotation: Optional[int] = 1 + # when rotate is not NONE, backup_count is required, default 10 + backup_count: Optional[int] = 10 + + +class LoggerConfig: + + def __init__( + self, + logger: str = "internal", + level: Level = Level.INFO, + formatter: Optional[str] = None, + console_config: ConsoleLoggerConfig = ConsoleLoggerConfig(), + file_config: FileLoggerConfig = FileLoggerConfig(), + ): + # global logger config + self._logger = logger + self._default_level = level + self._default_formatter = formatter + # console logger config + self._console_config = console_config + # file logger config + self._file_config = file_config + + self._set_default_config() + + def _set_default_config(self): + # update console logger config + if self._console_config.enabled: + if self._console_config.level is None: + self._console_config.level = self._default_level + if self._console_config.formatter is None: + self._console_config.formatter = self._default_formatter + + # update file logger config + if self._file_config.enabled: + if self._file_config.level is None: + self._file_config.level = self._default_level + if self._file_config.formatter is None: + self._file_config.formatter = self._default_formatter diff --git a/dubbo/logger/__init__.py b/dubbo/logger/__init__.py index de344ef..2c05a1f 100644 --- a/dubbo/logger/__init__.py +++ b/dubbo/logger/__init__.py @@ -13,4 +13,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._logger import Level, Logger, LoggerAdapter, LoggerFactory +from ._logger import Level, Logger, LoggerAdapter, LoggerFactory, RotateType diff --git a/dubbo/logger/_logger.py b/dubbo/logger/_logger.py index 865fb73..a82bb56 100644 --- a/dubbo/logger/_logger.py +++ b/dubbo/logger/_logger.py @@ -34,6 +34,20 @@ class Level(enum.Enum): FATAL = "FATAL" +@enum.unique +class RotateType(enum.Enum): + """ + The file rotating type enum. + """ + + # No rotating. + NONE = "NONE" + # Rotate the file by size. + SIZE = "SIZE" + # Rotate the file by time. 
+ TIME = "TIME" + + class Logger: """ Logger Interface, which is used to log messages. diff --git a/dubbo/logger/internal_logger.py b/dubbo/logger/internal_logger.py index 5aa6c87..031bdc6 100644 --- a/dubbo/logger/internal_logger.py +++ b/dubbo/logger/internal_logger.py @@ -93,9 +93,9 @@ def get_logger(self, name: str) -> Logger: # TODO enable config by args logger_instance = logging.getLogger(name) # Create a formatter - formatter = logging.Formatter( - "%(asctime)s | %(levelname)s | %(module)s:%(funcName)s:%(lineno)d - [Dubbo] %(message)s" - ) + default_format = "%(asctime)s | %(levelname)s | %(module)s:%(funcName)s:%(lineno)d - [Dubbo] %(message)s" + formatter = logging.Formatter(default_format) + # Add a console handler console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) @@ -131,3 +131,18 @@ def _update_level(self, level: Level) -> None: root_logger = logging.getLogger() # Set the logging level root_logger.setLevel(level.name) + + +if __name__ == "__main__": + logger_adapter = InternalLoggerAdapter() + logger = logger_adapter.get_logger("test") + logger.debug("test debug") + logger.info("test info") + logger.warning("test warning") + logger.error("test error") + logger.critical("test critical") + logger.fatal("test fatal") + try: + 1 / 0 + except ZeroDivisionError: + logger.exception("test exception") diff --git a/tests/common/tets_url.py b/tests/common/tets_url.py new file mode 100644 index 0000000..0f52abc --- /dev/null +++ b/tests/common/tets_url.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 1.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-1.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +from dubbo.common.url import URL + + +class TestUrl(unittest.TestCase): + + def test_str_to_url(self): + url_0 = URL.value_of( + "http://www.facebook.com/friends?param1=value1¶m2=value2" + ) + self.assertEqual("http", url_0.protocol) + self.assertEqual("www.facebook.com", url_0.host) + self.assertEqual(None, url_0.port) + self.assertEqual("friends", url_0.path) + self.assertEqual("value1", url_0.get_param("param1")) + self.assertEqual("value2", url_0.get_param("param2")) + + url_1 = URL.value_of("ftp://username:password@192.168.1.7:21/1/read.txt") + self.assertEqual("ftp", url_1.protocol) + self.assertEqual("username", url_1.username) + self.assertEqual("password", url_1.password) + self.assertEqual("192.168.1.7", url_1.host) + self.assertEqual(21, url_1.port) + self.assertEqual("192.168.1.7:21", url_1.address) + self.assertEqual("1/read.txt", url_1.path) + + url_2 = URL.value_of("file:///home/user1/router.js?type=script") + self.assertEqual("file", url_2.protocol) + self.assertEqual("home/user1/router.js", url_2.path) + + url_3 = URL.value_of( + "http%3A//www.facebook.com/friends%3Fparam1%3Dvalue1%26param2%3Dvalue2", + encoded=True, + ) + self.assertEqual("http", url_3.protocol) + self.assertEqual("www.facebook.com", url_3.host) + self.assertEqual(None, url_3.port) + self.assertEqual("friends", url_3.path) + self.assertEqual("value1", url_3.get_param("param1")) + self.assertEqual("value2", url_3.get_param("param2")) + + def test_url_to_str(self): + url_0 = URL( + protocol="tri", + host="127.0.0.1", + port=12, + username="username", + password="password", + path="path", + params={"type": "a"}, + ) + self.assertEqual( + "tri://username:password@127.0.0.1:12/path?type=a", url_0.to_string() + ) + + url_1 = URL( + protocol="tri", host="127.0.0.1", port=12, path="path", params={"type": "a"} + ) + self.assertEqual("tri://127.0.0.1:12/path?type=a", url_1.to_string()) + + url_2 = URL(protocol="tri", host="127.0.0.1", port=12, params={"type": "a"}) + self.assertEqual("tri://127.0.0.1:12/?type=a", url_2.to_string()) From 4bb4f8a2efd6b608027390b6f4f444d24478af06 Mon Sep 17 00:00:00 2001 From: zaki Date: Fri, 14 Jun 2024 18:35:46 +0800 Subject: [PATCH 16/32] fix: fix ci --- tests/common/tets_url.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/common/tets_url.py b/tests/common/tets_url.py index 0f52abc..40a3604 100644 --- a/tests/common/tets_url.py +++ b/tests/common/tets_url.py @@ -2,11 +2,11 @@ # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 1.0 +# The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. 
You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-1.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, From 81a06e64e33691fe87bf7bcaccaf64cf8712d690 Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 16 Jun 2024 01:57:56 +0800 Subject: [PATCH 17/32] feat: finish logger part --- .flake8 | 3 +- dubbo/__init__.py | 2 + dubbo/common/constants/__init__.py | 17 ++ dubbo/common/constants/logger_constants.py | 82 +++++++++ dubbo/common/extension/logger_extension.py | 19 +- dubbo/common/url.py | 96 ++++------ dubbo/config/__init__.py | 1 + dubbo/config/logger_config.py | 135 +++++++++----- dubbo/imports.py | 2 +- dubbo/logger/__init__.py | 8 +- dubbo/logger/internal/__init__.py | 15 ++ dubbo/logger/internal/logger.py | 75 ++++++++ dubbo/logger/internal/logger_adapter.py | 174 ++++++++++++++++++ dubbo/logger/internal_logger.py | 148 --------------- dubbo/logger/{_logger.py => logger.py} | 140 +++----------- dubbo/logger/logger_factory.py | 134 ++++++++++++++ .../common/extension/test_logger_extension.py | 11 +- tests/common/tets_url.py | 26 +-- tests/logger/test_internal_logger.py | 17 +- tests/logger/test_logger_factory.py | 49 +++++ 20 files changed, 747 insertions(+), 407 deletions(-) create mode 100644 dubbo/common/constants/__init__.py create mode 100644 dubbo/common/constants/logger_constants.py create mode 100644 dubbo/logger/internal/__init__.py create mode 100644 dubbo/logger/internal/logger.py create mode 100644 dubbo/logger/internal/logger_adapter.py delete mode 100644 dubbo/logger/internal_logger.py rename dubbo/logger/{_logger.py => logger.py} (60%) create mode 100644 dubbo/logger/logger_factory.py create mode 100644 tests/logger/test_logger_factory.py diff --git a/.flake8 b/.flake8 index f5b3b3c..233cd14 100644 --- a/.flake8 +++ b/.flake8 @@ -19,8 +19,7 @@ exclude = .idea, .git, __pycache__, - docs, - tests + docs per-file-ignores = __init__.py:F401 diff --git a/dubbo/__init__.py b/dubbo/__init__.py index 87db198..b31a846 100644 --- a/dubbo/__init__.py +++ b/dubbo/__init__.py @@ -13,4 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import dubbo.imports + from ._dubbo import Dubbo diff --git a/dubbo/common/constants/__init__.py b/dubbo/common/constants/__init__.py new file mode 100644 index 0000000..44dc90e --- /dev/null +++ b/dubbo/common/constants/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .logger_constants import (LoggerConstants, LoggerFileRotateType, + LoggerLevel) diff --git a/dubbo/common/constants/logger_constants.py b/dubbo/common/constants/logger_constants.py new file mode 100644 index 0000000..14ee10b --- /dev/null +++ b/dubbo/common/constants/logger_constants.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum +from functools import cache + + +@enum.unique +class LoggerLevel(enum.Enum): + """ + The logging level enum. + """ + + DEBUG = "DEBUG" + INFO = "INFO" + WARNING = "WARNING" + ERROR = "ERROR" + CRITICAL = "CRITICAL" + FATAL = "FATAL" + + @classmethod + @cache + def get_level(cls, level_value: str) -> "LoggerLevel": + level_value = level_value.upper() + for level in cls: + if level_value == level.value: + return level + raise ValueError("Log level invalid") + + +@enum.unique +class LoggerFileRotateType(enum.Enum): + """ + The file rotating type enum. + """ + + # No rotating. + NONE = "NONE" + # Rotate the file by size. + SIZE = "SIZE" + # Rotate the file by time. + TIME = "TIME" + + +class LoggerConstants: + """logger configuration constants.""" + + """logger config keys""" + # global config + LOGGER_LEVEL_KEY = "logger.level" + LOGGER_DRIVER_KEY = "logger.driver" + LOGGER_FORMAT_KEY = "logger.format" + + # console config + LOGGER_CONSOLE_ENABLED_KEY = "logger.console.enable" + LOGGER_CONSOLE_FORMAT_KEY = "logger.console.format" + + # file logger + LOGGER_FILE_ENABLED_KEY = "logger.file.enable" + LOGGER_FILE_FORMAT_KEY = "logger.file.format" + LOGGER_FILE_DIR_KEY = "logger.file.dir" + LOGGER_FILE_NAME_KEY = "logger.file.name" + LOGGER_FILE_ROTATE_KEY = "logger.file.rotate" + LOGGER_FILE_MAX_BYTES_KEY = "logger.file.maxbytes" + LOGGER_FILE_INTERVAL_KEY = "logger.file.interval" + LOGGER_FILE_BACKUP_COUNT_KEY = "logger.file.backupcount" + + """some logger default value""" + LOGGER_DRIVER_VALUE = "internal" + LOGGER_FILE_NAME_VALUE = "dubbo.log" diff --git a/dubbo/common/extension/logger_extension.py b/dubbo/common/extension/logger_extension.py index 998f029..71c3470 100644 --- a/dubbo/common/extension/logger_extension.py +++ b/dubbo/common/extension/logger_extension.py @@ -16,12 +16,14 @@ """ This module provides an extension point for logger adapters. -Note: Type annotations are not fully used here (LoggerAdapter object is not explicitly specified) -because it would cause a circular reference issue. """ +from typing import Dict + +from dubbo.common.url import URL +from dubbo.logger import LoggerAdapter # A dictionary to store all the logger adapters. 
key: name, value: logger adapter class -_logger_adapter_dict = {} +_logger_adapter_dict: Dict[str, type[LoggerAdapter]] = {} def register_logger_adapter(name: str): @@ -39,14 +41,14 @@ def register_logger_adapter(name: str): A decorator function that registers the logger class. """ - def decorator(cls): + def wrapper(cls): _logger_adapter_dict[name] = cls return cls - return decorator + return wrapper -def get_logger_adapter(name: str, *args, **kwargs): +def get_logger_adapter(name: str, config: URL) -> LoggerAdapter: """ Get a logger adapter instance by name. @@ -55,8 +57,7 @@ def get_logger_adapter(name: str, *args, **kwargs): Args: name (str): The name of the logger adapter to retrieve. - *args: Variable length argument list for the logger adapter constructor. - **kwargs: Arbitrary keyword arguments for the logger adapter constructor. + config (URL): The config of the logger adapter to retrieve. Returns: LoggerAdapter: An instance of the requested logger adapter. @@ -64,4 +65,4 @@ def get_logger_adapter(name: str, *args, **kwargs): KeyError: If no logger adapter is registered under the provided name. """ logger_adapter = _logger_adapter_dict[name] - return logger_adapter(*args, **kwargs) + return logger_adapter(config) diff --git a/dubbo/common/url.py b/dubbo/common/url.py index 739a3a7..bb78f49 100644 --- a/dubbo/common/url.py +++ b/dubbo/common/url.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +from typing import Any, Dict, Optional from urllib import parse @@ -30,43 +30,43 @@ class URL: def __init__( self, - protocol: Optional[str] = None, - host: Optional[str] = None, - port: Optional[int] = None, + protocol: str, + host: Optional[str], + port: Optional[int], username: Optional[str] = None, password: Optional[str] = None, path: Optional[str] = None, - params: Optional[Dict[str, str]] = None, + parameters: Optional[Dict[str, str]] = None, ): """ Initializes the URL with the given components. Args: - protocol (Optional[str]): The protocol of the URL. + protocol (str): The protocol of the URL. host (Optional[str]): The host of the URL. port (Optional[int]): The port number of the URL. username (Optional[str]): The username for URL authentication. password (Optional[str]): The password for URL authentication. path (Optional[str]): The path of the URL. - params (Optional[Dict[str, str]]): The query parameters of the URL. + parameters (Optional[Dict[str, str]]): The query parameters of the URL. """ self._protocol = protocol self._host = host self._port = port - # address = host:port - self._address = None if not host else f"{host}:{port}" if port else host + # location -> host:port + self._location = f"{host}:{port}" if host and port else host or None self._username = username self._password = password self._path = path - self._params = params + self._parameters = parameters or {} @property - def protocol(self) -> Optional[str]: + def protocol(self) -> str: """ Gets the protocol of the URL. Returns: - Optional[str]: The protocol of the URL. + str: The protocol of the URL. """ return self._protocol @@ -81,30 +81,14 @@ def protocol(self, protocol: str) -> None: self._protocol = protocol @property - def address(self) -> Optional[str]: + def location(self) -> Optional[str]: """ - Gets the address (host:port) of the URL. + Gets the location (host:port) of the URL. 
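The extension point is now typed against LoggerAdapter and takes a URL as its only constructor argument, so a third-party adapter can plug in like this (the "custom" name and CustomLoggerAdapter class are hypothetical):

    from dubbo.common import extension
    from dubbo.common.url import URL
    from dubbo.logger import Logger, LoggerAdapter

    @extension.register_logger_adapter("custom")
    class CustomLoggerAdapter(LoggerAdapter):
        def __init__(self, config: URL):
            super().__init__(config)
            self._config = config

        def get_logger(self, name: str) -> Logger:
            ...  # wrap whatever logging backend the adapter targets

    adapter = extension.get_logger_adapter("custom", URL("custom", None, None))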
Returns: - Optional[str]: The address of the URL. + Optional[str]: The location of the URL. """ - return self._address - - @address.setter - def address(self, address: str) -> None: - """ - Sets the address (host:port) of the URL. - - Args: - address (str): The address to set. - """ - self._address = address - if ":" in address: - self._host, port = address.split(":") - self._port = int(port) - else: - self._host = address - self._port = None + return self._location @property def host(self) -> Optional[str]: @@ -125,7 +109,7 @@ def host(self, host: str) -> None: host (str): The host to set. """ self._host = host - self._address = f"{host}:{self.port}" if self.port else host + self._location = f"{host}:{self.port}" if self.port else host @property def port(self) -> Optional[int]: @@ -145,8 +129,8 @@ def port(self, port: int) -> None: Args: port (int): The port to set. """ - self._port = port - self._address = f"{self.host}:{port}" if port else self.host + self._port = max(port, 0) + self._location = f"{self.host}:{port}" if port else self.host @property def username(self) -> Optional[str]: @@ -209,26 +193,26 @@ def path(self, path: str) -> None: self._path = path @property - def params(self) -> Optional[Dict[str, str]]: + def parameters(self) -> Dict[str, str]: """ Gets the query parameters of the URL. Returns: Optional[Dict[str, str]]: The query parameters of the URL. """ - return self._params + return self._parameters - @params.setter - def params(self, params: Dict[str, str]) -> None: + @parameters.setter + def parameters(self, parameters: Dict[str, str]) -> None: """ Sets the query parameters of the URL. Args: - params (Dict[str, str]): The query parameters to set. + parameters (Dict[str, str]): The query parameters to set. """ - self._params = params + self._parameters = parameters - def get_param(self, key: str) -> Optional[str]: + def get_parameter(self, key: str) -> Optional[str]: """ Gets a query parameter from the URL. @@ -238,21 +222,19 @@ def get_param(self, key: str) -> Optional[str]: Returns: str or None: The parameter value. If the parameter does not exist, returns None. """ - return self._params.get(key, None) if self._params else None + return self._parameters.get(key, None) - def add_param(self, key: str, value: str) -> None: + def add_parameter(self, key: str, value: Any) -> None: """ Adds a query parameter to the URL. Args: key (str): The parameter name. - value (str): The parameter value. + value (Any): The parameter value. """ - if not self._params: - self._params = {} - self._params[key] = value + self._parameters[key] = str(value) if value is not None else "" - def to_string(self, encode: bool = False) -> str: + def build_string(self, encode: bool = False) -> str: """ Generates the URL string based on the current components. @@ -270,15 +252,15 @@ def to_string(self, encode: bool = False) -> str: if self.password: url += f":{self.password}" url += "@" - # Set Address - url += self.address if self.address else "" + # Set location + url += self.location if self.location else "" # Set path url += "/" if self.path: url += f"{self.path}" # Set params - if self.params: - url += "?" + "&".join([f"{k}={v}" for k, v in self.params.items()]) + if self.parameters: + url += "?" + "&".join([f"{k}={v}" for k, v in self.parameters.items()]) # If the URL needs to be encoded, encode it if encode: url = parse.quote(url) @@ -291,7 +273,7 @@ def __str__(self) -> str: Returns: str: The generated URL string. 
""" - return self.to_string() + return self.build_string() @classmethod def value_of(cls, url: str, encoded: bool = False) -> "URL": @@ -322,7 +304,9 @@ def value_of(cls, url: str, encoded: bool = False) -> "URL": port = parsed_url.port username = parsed_url.username password = parsed_url.password - params = {k: v[0] for k, v in parse.parse_qs(parsed_url.query).items()} + parameters = {k: v[0] for k, v in parse.parse_qs(parsed_url.query).items()} path = parsed_url.path.lstrip("/") - return URL(protocol, host, port, username, password, path, params) + if not protocol: + raise ValueError("Invalid URL format: missing protocol.") + return URL(protocol, host, port, username, password, path, parameters) diff --git a/dubbo/config/__init__.py b/dubbo/config/__init__.py index bcba37a..8adaf92 100644 --- a/dubbo/config/__init__.py +++ b/dubbo/config/__init__.py @@ -13,3 +13,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .logger_config import LoggerConfig diff --git a/dubbo/config/logger_config.py b/dubbo/config/logger_config.py index 6ea97f8..ba569a0 100644 --- a/dubbo/config/logger_config.py +++ b/dubbo/config/logger_config.py @@ -15,73 +15,114 @@ # limitations under the License. import os from dataclasses import dataclass -from typing import Optional +from typing import Dict, Optional -from dubbo.logger import Level, RotateType +from dubbo.common import extension +from dubbo.common.constants import (LoggerConstants, LoggerFileRotateType, + LoggerLevel) +from dubbo.common.url import URL +from dubbo.logger import loggerFactory @dataclass class ConsoleLoggerConfig: + """Console logger configuration""" + # default is open console logger - enabled: bool = True - # default level is None, use the global level - level: Optional[Level] = None - # default formatter is None, use the global formatter - formatter: Optional[str] = None + console_enabled: bool = True + # default console formatter is None, use the global formatter + console_formatter: Optional[str] = None + + def check(self): + pass + + def dict(self) -> Dict[str, str]: + return { + LoggerConstants.LOGGER_CONSOLE_ENABLED_KEY: str(self.console_enabled), + LoggerConstants.LOGGER_CONSOLE_FORMAT_KEY: self.console_formatter or "", + } @dataclass class FileLoggerConfig: + """File logger configuration""" + # default is close file logger - enabled: bool = False - # default level is None, use the global level - level: Optional[Level] = None - # default formatter is None, use the global formatter - formatter: Optional[str] = None + file_enabled: bool = False + # default file formatter is None, use the global formatter + file_formatter: Optional[str] = None # default log file dir is user home dir - file_dir: Optional[str] = os.path.expanduser("~") + file_dir: str = os.path.expanduser("~") + # default log file name is "dubbo.log" + file_name: str = LoggerConstants.LOGGER_FILE_NAME_VALUE # default no rotate - rotate: Optional[RotateType] = RotateType.NONE + rotate: LoggerFileRotateType = LoggerFileRotateType.NONE # when rotate is SIZE, max_bytes is required, default 10M - max_bytes: Optional[int] = 1024 * 1024 * 10 - # when rotate is TIME, rotation is required, unit is day, default 1 - rotation: Optional[int] = 1 + max_bytes: int = 1024 * 1024 * 10 + # when rotate is TIME, interval is required, unit is day, default 1 + interval: int = 1 # when rotate is not NONE, backup_count is required, default 10 - backup_count: 
Optional[int] = 10 + backup_count: int = 10 + + def check(self) -> None: + if self.file_enabled: + if self.rotate == LoggerFileRotateType.SIZE and self.max_bytes < 0: + raise ValueError("Max bytes can't be less than 0") + elif self.rotate == LoggerFileRotateType.TIME and self.interval < 1: + raise ValueError("Interval can't be less than 1") + + def dict(self) -> Dict[str, str]: + return { + LoggerConstants.LOGGER_FILE_ENABLED_KEY: str(self.file_enabled), + LoggerConstants.LOGGER_FILE_FORMAT_KEY: self.file_formatter or "", + LoggerConstants.LOGGER_FILE_DIR_KEY: self.file_dir, + LoggerConstants.LOGGER_FILE_NAME_KEY: self.file_name, + LoggerConstants.LOGGER_FILE_ROTATE_KEY: self.rotate.value, + LoggerConstants.LOGGER_FILE_MAX_BYTES_KEY: str(self.max_bytes), + LoggerConstants.LOGGER_FILE_INTERVAL_KEY: str(self.interval), + LoggerConstants.LOGGER_FILE_BACKUP_COUNT_KEY: str(self.backup_count), + } class LoggerConfig: def __init__( self, - logger: str = "internal", - level: Level = Level.INFO, + driver: str = LoggerConstants.LOGGER_DRIVER_VALUE, + level: LoggerLevel = LoggerLevel.DEBUG, formatter: Optional[str] = None, - console_config: ConsoleLoggerConfig = ConsoleLoggerConfig(), - file_config: FileLoggerConfig = FileLoggerConfig(), + console: ConsoleLoggerConfig = ConsoleLoggerConfig(), + file: FileLoggerConfig = FileLoggerConfig(), ): - # global logger config - self._logger = logger - self._default_level = level - self._default_formatter = formatter - # console logger config - self._console_config = console_config - # file logger config - self._file_config = file_config - - self._set_default_config() - - def _set_default_config(self): - # update console logger config - if self._console_config.enabled: - if self._console_config.level is None: - self._console_config.level = self._default_level - if self._console_config.formatter is None: - self._console_config.formatter = self._default_formatter - - # update file logger config - if self._file_config.enabled: - if self._file_config.level is None: - self._file_config.level = self._default_level - if self._file_config.formatter is None: - self._file_config.formatter = self._default_formatter + # set global config + self._driver = driver + self._level = level + self._formatter = formatter + # set console config + self._console = console + self._console.check() + # set file comfig + self._file = file + self._file.check() + + def get_url(self) -> URL: + # get LoggerConfig parameters + parameters: Dict[str, str] = { + **self._console.dict(), + **self._file.dict(), + LoggerConstants.LOGGER_DRIVER_KEY: self._driver, + LoggerConstants.LOGGER_LEVEL_KEY: self._level.value, + LoggerConstants.LOGGER_FORMAT_KEY: self._formatter or "", + } + + return URL( + protocol=self._driver, + host=self._level.value, + port=None, + parameters=parameters, + ) + + def init(self): + # get logger_adapter and initialize loggerFactory + logger_adapter = extension.get_logger_adapter(self._driver, self.get_url()) + loggerFactory.logger_adapter = logger_adapter diff --git a/dubbo/imports.py b/dubbo/imports.py index 1e860c9..6d4c314 100644 --- a/dubbo/imports.py +++ b/dubbo/imports.py @@ -16,4 +16,4 @@ """Utilizing the mechanism of module loading to complete the registration of plugins.""" -import dubbo.logger.internal_logger +import dubbo.logger.internal.logger_adapter diff --git a/dubbo/logger/__init__.py b/dubbo/logger/__init__.py index 2c05a1f..f685669 100644 --- a/dubbo/logger/__init__.py +++ b/dubbo/logger/__init__.py @@ -13,4 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF 
ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._logger import Level, Logger, LoggerAdapter, LoggerFactory, RotateType + +from .logger import Logger, LoggerAdapter +from .logger_factory import LoggerFactory as _LoggerFactory + +loggerFactory = _LoggerFactory() + +__all__ = ["Logger", "LoggerAdapter", "loggerFactory"] diff --git a/dubbo/logger/internal/__init__.py b/dubbo/logger/internal/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/logger/internal/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/logger/internal/logger.py b/dubbo/logger/internal/logger.py new file mode 100644 index 0000000..5e87761 --- /dev/null +++ b/dubbo/logger/internal/logger.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict + +from dubbo.common.constants import LoggerLevel +from dubbo.logger import Logger + +# The mapping from the logging level to the internal logging level. +_level_map: Dict[LoggerLevel, int] = { + LoggerLevel.DEBUG: logging.DEBUG, + LoggerLevel.INFO: logging.INFO, + LoggerLevel.WARNING: logging.WARNING, + LoggerLevel.ERROR: logging.ERROR, + LoggerLevel.CRITICAL: logging.CRITICAL, + LoggerLevel.FATAL: logging.FATAL, +} + + +class InternalLogger(Logger): + """ + The internal logger implementation. + """ + + def __init__(self, internal_logger: logging.Logger): + self._logger = internal_logger + + def _log(self, level: int, msg: str, *args, **kwargs) -> None: + # Add the stacklevel to the keyword arguments. 
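Putting the pieces of this patch together, logging is now bootstrapped through LoggerConfig and the module-level loggerFactory; a sketch of the intended flow, assuming the new LoggerFactory keeps the get_logger(name) entry point of the class it replaces (directory and messages are illustrative):

    import tempfile

    from dubbo.common.constants import LoggerFileRotateType, LoggerLevel
    from dubbo.config import LoggerConfig
    from dubbo.config.logger_config import FileLoggerConfig
    from dubbo.logger import loggerFactory

    config = LoggerConfig(
        level=LoggerLevel.INFO,
        file=FileLoggerConfig(
            file_enabled=True,
            file_dir=tempfile.gettempdir(),      # the log directory must already exist
            rotate=LoggerFileRotateType.SIZE,
            max_bytes=5 * 1024 * 1024,
        ),
    )
    config.init()  # resolves the adapter by driver name and binds it to loggerFactory
    logger = loggerFactory.get_logger("dubbo.demo")
    logger.info("logger configured from LoggerConfig")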
+ kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 2 + self._logger.log(level, msg, *args, **kwargs) + + def log(self, level: LoggerLevel, msg: str, *args, **kwargs) -> None: + self._log(_level_map[level], msg, *args, **kwargs) + + def debug(self, msg: str, *args, **kwargs) -> None: + self._log(logging.DEBUG, msg, *args, **kwargs) + + def info(self, msg: str, *args, **kwargs) -> None: + self._log(logging.INFO, msg, *args, **kwargs) + + def warning(self, msg: str, *args, **kwargs) -> None: + self._log(logging.WARNING, msg, *args, **kwargs) + + def error(self, msg: str, *args, **kwargs) -> None: + self._log(logging.ERROR, msg, *args, **kwargs) + + def critical(self, msg: str, *args, **kwargs) -> None: + self._log(logging.CRITICAL, msg, *args, **kwargs) + + def fatal(self, msg: str, *args, **kwargs) -> None: + self._log(logging.FATAL, msg, *args, **kwargs) + + def exception(self, msg: str, *args, **kwargs) -> None: + if kwargs.get("exc_info") is None: + kwargs["exc_info"] = True + self.error(msg, *args, **kwargs) + + def is_enabled_for(self, level: LoggerLevel) -> bool: + logging_level = _level_map.get(level) + return self._logger.isEnabledFor(logging_level) if logging_level else False diff --git a/dubbo/logger/internal/logger_adapter.py b/dubbo/logger/internal/logger_adapter.py new file mode 100644 index 0000000..4215a25 --- /dev/null +++ b/dubbo/logger/internal/logger_adapter.py @@ -0,0 +1,174 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from functools import cache +from logging import handlers + +from dubbo.common import extension +from dubbo.common.constants import (LoggerConstants, LoggerFileRotateType, + LoggerLevel) +from dubbo.common.url import URL +from dubbo.logger import Logger, LoggerAdapter +from dubbo.logger.internal.logger import InternalLogger + +"""This module provides the internal logger implementation. -> logging module""" + + +@extension.register_logger_adapter("internal") +class InternalLoggerAdapter(LoggerAdapter): + """ + Internal logger adapter. + Responsible for internal logger creation, encapsulated the logging.getLogger() method + """ + + _default_format = "%(asctime)s | %(levelname)s | %(module)s:%(funcName)s:%(lineno)d - [Dubbo] %(message)s" + + def __init__(self, config: URL): + super().__init__(config) + self._config = config + # Set level + level_name = config.parameters.get(LoggerConstants.LOGGER_LEVEL_KEY) + self._level = ( + LoggerLevel.get_level(level_name) if level_name else LoggerLevel.DEBUG + ) + self._update_level() + # Set format + self._format_str = ( + config.parameters.get(LoggerConstants.LOGGER_FORMAT_KEY) + or self._default_format + ) + + def get_logger(self, name: str) -> Logger: + """ + Create a logger instance by name. + Args: + name (str): The logger name. 
+ Returns: + Logger: The InternalLogger instance. + """ + logger_instance = logging.getLogger(name) + # clean up handlers + for handler in logger_instance.handlers: + logger_instance.removeHandler(handler) + parameters = self._config.parameters + + # Add console handler + if parameters.get(LoggerConstants.LOGGER_CONSOLE_ENABLED_KEY) == str(True): + logger_instance.addHandler(self._get_console_handler()) + + # Add file handler + if parameters.get(LoggerConstants.LOGGER_FILE_ENABLED_KEY) == str(True): + logger_instance.addHandler(self._get_file_handler()) + + return InternalLogger(logger_instance) + + @cache + def _get_console_handler(self) -> logging.StreamHandler: + """ + Get the console handler.(Avoid duplicate consoleHandler creation with @cache) + Returns: + logging.StreamHandler: The console handler. + """ + parameters = self._config.parameters + console_handler = logging.StreamHandler() + console_format_str = ( + parameters.get(LoggerConstants.LOGGER_CONSOLE_FORMAT_KEY) + or self._format_str + ) + console_formatter = logging.Formatter(console_format_str) + console_handler.setFormatter(console_formatter) + + return console_handler + + @cache + def _get_file_handler(self) -> logging.Handler: + """ + Get the file handler.(Avoid duplicate fileHandler creation with @cache) + Returns: + logging.Handler: The file handler. + """ + parameters = self._config.parameters + # Get file path + file_dir = parameters[LoggerConstants.LOGGER_FILE_DIR_KEY] + file_name = ( + parameters[LoggerConstants.LOGGER_FILE_NAME_KEY] + or LoggerConstants.LOGGER_FILE_NAME_VALUE + ) + file_path = os.path.join(file_dir, file_name) + # Get backup count + backup_count = int( + parameters.get(LoggerConstants.LOGGER_FILE_BACKUP_COUNT_KEY) or 0 + ) + # Get rotate type + rotate_type = parameters.get(LoggerConstants.LOGGER_FILE_ROTATE_KEY) + + # Set file Handler + file_handler: logging.Handler + if rotate_type == LoggerFileRotateType.SIZE.value: + # Set RotatingFileHandler + max_bytes = int(parameters[LoggerConstants.LOGGER_FILE_MAX_BYTES_KEY]) + file_handler = handlers.RotatingFileHandler( + file_path, maxBytes=max_bytes, backupCount=backup_count + ) + elif rotate_type == LoggerFileRotateType.TIME.value: + # Set TimedRotatingFileHandler + interval = int(parameters[LoggerConstants.LOGGER_FILE_INTERVAL_KEY]) + file_handler = handlers.TimedRotatingFileHandler( + file_path, interval=interval, backupCount=backup_count + ) + else: + # Set FileHandler + file_handler = logging.FileHandler(file_path) + # Add file_handler + file_format_str = ( + parameters.get(LoggerConstants.LOGGER_FILE_FORMAT_KEY) or self._format_str + ) + file_formatter = logging.Formatter(file_format_str) + file_handler.setFormatter(file_formatter) + return file_handler + + @property + def level(self) -> LoggerLevel: + """ + Get the logging level. + Returns: + LoggerLevel: The logging level. + """ + return self._level + + @level.setter + def level(self, level: LoggerLevel) -> None: + """ + Set the logging level. + Args: + level (LoggerLevel): The logging level. + """ + if level == self._level or level is None: + return + self._level = level + self._update_level() + + def _update_level(self): + """ + Update log level. 
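The adapter can also be driven directly by a URL carrying the same keys that LoggerConfig.get_url() produces, which is handy for isolated tests (values are illustrative):

    from dubbo.common.constants import LoggerConstants
    from dubbo.common.url import URL
    from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter

    url = URL(
        "internal", None, None,
        parameters={
            LoggerConstants.LOGGER_LEVEL_KEY: "INFO",
            LoggerConstants.LOGGER_CONSOLE_ENABLED_KEY: "True",  # console only, no file handler
        },
    )
    adapter = InternalLoggerAdapter(url)
    adapter.get_logger("demo").info("console-only logger at INFO")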
+ Complete the log level change by modifying the root logger + """ + # Get the root logger + root_logger = logging.getLogger() + # Set the logging level + root_logger.setLevel(self._level.name) diff --git a/dubbo/logger/internal_logger.py b/dubbo/logger/internal_logger.py deleted file mode 100644 index 031bdc6..0000000 --- a/dubbo/logger/internal_logger.py +++ /dev/null @@ -1,148 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Dict - -from dubbo.common import extension -from dubbo.logger import Level, Logger, LoggerAdapter - -"""This module provides the internal logger implementation. -> logging module""" - -# The mapping from the logging level to the internal logging level. -_level_map: Dict[Level, int] = { - Level.DEBUG: logging.DEBUG, - Level.INFO: logging.INFO, - Level.WARNING: logging.WARNING, - Level.ERROR: logging.ERROR, - Level.CRITICAL: logging.CRITICAL, - Level.FATAL: logging.FATAL, -} - - -class InternalLogger(Logger): - """ - The internal logger implementation. - """ - - def __init__(self, internal_logger: logging.Logger): - self._logger = internal_logger - - def _log(self, level: int, msg: str, *args, **kwargs) -> None: - # Add the stacklevel to the keyword arguments. - kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 2 - self._logger.log(level, msg, *args, **kwargs) - - def log(self, level: Level, msg: str, *args, **kwargs) -> None: - self._log(_level_map[level], msg, *args, **kwargs) - - def debug(self, msg: str, *args, **kwargs) -> None: - self._log(logging.DEBUG, msg, *args, **kwargs) - - def info(self, msg: str, *args, **kwargs) -> None: - self._log(logging.INFO, msg, *args, **kwargs) - - def warning(self, msg: str, *args, **kwargs) -> None: - self._log(logging.WARNING, msg, *args, **kwargs) - - def error(self, msg: str, *args, **kwargs) -> None: - self._log(logging.ERROR, msg, *args, **kwargs) - - def critical(self, msg: str, *args, **kwargs) -> None: - self._log(logging.CRITICAL, msg, *args, **kwargs) - - def fatal(self, msg: str, *args, **kwargs) -> None: - self._log(logging.FATAL, msg, *args, **kwargs) - - def exception(self, msg: str, *args, **kwargs) -> None: - if kwargs.get("exc_info") is None: - kwargs["exc_info"] = True - self.error(msg, *args, **kwargs) - - -@extension.register_logger_adapter("internal") -class InternalLoggerAdapter(LoggerAdapter): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Set the default logging level to DEBUG. - self._level = Level.DEBUG - self._update_level(Level.DEBUG) - - def get_logger(self, name: str) -> Logger: - """ - Create a logger instance by name. - Args: - name (str): The logger name. - Returns: - Logger: The InternalLogger instance. 
- """ - # TODO enable config by args - logger_instance = logging.getLogger(name) - # Create a formatter - default_format = "%(asctime)s | %(levelname)s | %(module)s:%(funcName)s:%(lineno)d - [Dubbo] %(message)s" - formatter = logging.Formatter(default_format) - - # Add a console handler - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - logger_instance.addHandler(console_handler) - return InternalLogger(logger_instance) - - @property - def level(self) -> Level: - """ - Get the logging level. - Returns: - Level: The logging level. - """ - return self._level - - @level.setter - def level(self, level: Level) -> None: - """ - Set the logging level. - Args: - level (Level): The logging level. - """ - if level == self._level or level is None: - return - self._level = level - self._update_level(level) - - def _update_level(self, level: Level) -> None: - """ - Update the logging level. - """ - # Get the root logger - root_logger = logging.getLogger() - # Set the logging level - root_logger.setLevel(level.name) - - -if __name__ == "__main__": - logger_adapter = InternalLoggerAdapter() - logger = logger_adapter.get_logger("test") - logger.debug("test debug") - logger.info("test info") - logger.warning("test warning") - logger.error("test error") - logger.critical("test critical") - logger.fatal("test fatal") - try: - 1 / 0 - except ZeroDivisionError: - logger.exception("test exception") diff --git a/dubbo/logger/_logger.py b/dubbo/logger/logger.py similarity index 60% rename from dubbo/logger/_logger.py rename to dubbo/logger/logger.py index a82bb56..1cbb97f 100644 --- a/dubbo/logger/_logger.py +++ b/dubbo/logger/logger.py @@ -13,39 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum -import threading -from typing import Any, Dict +from typing import Any -from dubbo.common import extension - - -@enum.unique -class Level(enum.Enum): - """ - The logging level enum. - """ - - DEBUG = "DEBUG" - INFO = "INFO" - WARNING = "WARNING" - ERROR = "ERROR" - CRITICAL = "CRITICAL" - FATAL = "FATAL" - - -@enum.unique -class RotateType(enum.Enum): - """ - The file rotating type enum. - """ - - # No rotating. - NONE = "NONE" - # Rotate the file by size. - SIZE = "SIZE" - # Rotate the file by time. - TIME = "TIME" +from dubbo.common.constants import LoggerLevel +from dubbo.common.url import URL class Logger: @@ -53,12 +24,12 @@ class Logger: Logger Interface, which is used to log messages. """ - def log(self, level: Level, msg: str, *args: Any, **kwargs: Any) -> None: + def log(self, level: LoggerLevel, msg: str, *args: Any, **kwargs: Any) -> None: """ Log a message at the specified logging level. Args: - level (Level): The logging level. + level (LoggerLevel): The logging level. msg (str): The log message. *args (Any): Additional positional arguments. **kwargs (Any): Additional keyword arguments. @@ -142,19 +113,28 @@ def exception(self, msg: str, *args, **kwargs) -> None: """ raise NotImplementedError("exception() is not implemented.") + def is_enabled_for(self, level: LoggerLevel) -> bool: + """ + Is this logger enabled for level 'level'? + Args: + level (LoggerLevel): The logging level. + Return: + bool: Whether the logging level is enabled. + """ + raise ValueError("is_enabled_for() is not implemented.") + class LoggerAdapter: """ Logger Adapter Interface, which is used to support different logging libraries. 
""" - def __init__(self, *args, **kwargs): + def __init__(self, config: URL): """ Initialize the logger adapter. Args: - *args (Any): Additional positional arguments. - **kwargs (Any): Additional keyword arguments. + config(URL): config (URL): The config of the logger adapter. """ pass @@ -171,99 +151,21 @@ def get_logger(self, name: str) -> Logger: raise NotImplementedError("get_logger() is not implemented.") @property - def level(self) -> Level: + def level(self) -> LoggerLevel: """ Get the current logging level. Returns: - Level: The current logging level. + LoggerLevel: The current logging level. """ raise NotImplementedError("get_level() is not implemented.") @level.setter - def level(self, level: Level) -> None: + def level(self, level: LoggerLevel) -> None: """ Set the logging level. Args: - level (Level): The logging level to set. + level (LoggerLevel): The logging level to set. """ raise NotImplementedError("set_level() is not implemented.") - - -class LoggerFactory: - """ - Factory class to create loggers. - """ - - # The logger adapter. - _logger_adapter: LoggerAdapter - - # A dictionary to store all the loggers. - _loggers: Dict[str, Logger] = {} - - # A lock to protect the loggers. - _logger_lock = threading.Lock() - - @classmethod - def get_logger_adapter(cls) -> LoggerAdapter: - """ - Get the logger adapter. - - Returns: - LoggerAdapter: The current logger adapter. - """ - return cls._logger_adapter - - @classmethod - def set_logger_adapter(cls, logger_adapter: str) -> None: - """ - Set the logger adapter. - - Args: - logger_adapter (str): The name of the logger adapter to set. - """ - cls._logger_adapter = extension.get_logger_adapter(logger_adapter) - # update all loggers - cls._loggers = { - name: cls._logger_adapter.get_logger(name) for name in cls._loggers - } - - @classmethod - def get_logger(cls, name: str) -> Logger: - """ - Get the logger by name. - - Args: - name (str): The name of the logger to retrieve. - - Returns: - Logger: An instance of the requested logger. - """ - logger = cls._loggers.get(name) - if logger is None: - with cls._logger_lock: - if name not in cls._loggers: - cls._loggers[name] = cls._logger_adapter.get_logger(name) - logger = cls._loggers[name] - return logger - - @classmethod - def set_level(cls, level: Level) -> None: - """ - Set the logging level. - - Args: - level (Level): The logging level to set. - """ - cls._logger_adapter.level = level - - @classmethod - def get_level(cls) -> Level: - """ - Get the current logging level. - - Returns: - Level: The current logging level. - """ - return cls._logger_adapter.level diff --git a/dubbo/logger/logger_factory.py b/dubbo/logger/logger_factory.py new file mode 100644 index 0000000..d3545cf --- /dev/null +++ b/dubbo/logger/logger_factory.py @@ -0,0 +1,134 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import threading + +from dubbo.common.constants import LoggerLevel +from dubbo.logger.logger import Logger, LoggerAdapter + + +def initialize_check(func): + """ + Checks if the logger factory instance is initialized. + """ + + def wrapper(self, *args, **kwargs): + if not self._initialized: + with self._initialize_lock: + if not self._initialized: + # initialize LoggerFactory + from dubbo.config import LoggerConfig + + config = LoggerConfig() + config.init() + self._initialized = True + return func(self, *args, **kwargs) + + return wrapper + + +class LoggerFactory: + """ + Factory class to create loggers. (single object) + """ + + _instance = None + _instance_lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + if not cls._instance: + with cls._instance_lock: + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + self._logger_adapter = None + # A dictionary to store all the loggers. + self._loggers = {} + # A lock to protect the loggers. + self._loggers_lock = threading.Lock() + # Initialization flag + self._initialized = False + self._initialize_lock = threading.Lock() + + @property + @initialize_check + def logger_adapter(self) -> LoggerAdapter: + return self._logger_adapter + + @logger_adapter.setter + def logger_adapter(self, logger_adapter) -> None: + """ + Set logger config + """ + self._logger_adapter = logger_adapter + with self._loggers_lock: + # update all loggers + self._loggers = { + name: self._logger_adapter.get_logger(name) for name in self._loggers + } + self._initialized = True + + @initialize_check + def get_logger_adapter(self) -> LoggerAdapter: + """ + Get the logger adapter. + + Returns: + LoggerAdapter: The current logger adapter. + """ + return self._logger_adapter + + @initialize_check + def get_logger(self, name: str) -> Logger: + """ + Get the logger by name. + + Args: + name (str): The name of the logger to retrieve. + + Returns: + Logger: An instance of the requested logger. + """ + logger = self._loggers.get(name) + if logger is None: + with self._loggers_lock: + if name not in self._loggers: + self._loggers[name] = self._logger_adapter.get_logger(name) + logger = self._loggers[name] + return logger + + @property + @initialize_check + def level(self) -> LoggerLevel: + """ + Get the current logging level. + + Returns: + LoggerLevel: The current logging level. + """ + return self._logger_adapter.level + + @level.setter + @initialize_check + def level(self, level: LoggerLevel) -> None: + """ + Set the logging level. + + Args: + level (LoggerLevel): The logging level to set. + """ + self._logger_adapter.level = level diff --git a/tests/common/extension/test_logger_extension.py b/tests/common/extension/test_logger_extension.py index 96a50c0..b5eda81 100644 --- a/tests/common/extension/test_logger_extension.py +++ b/tests/common/extension/test_logger_extension.py @@ -15,19 +15,22 @@ # limitations under the License. import unittest +from dubbo.common import extension +from dubbo.config import LoggerConfig + class TestLoggerExtension(unittest.TestCase): def test_logger_extension(self): - import dubbo.imports - from dubbo.common import extension # Test the get_logger_adapter method. - logger_adapter = extension.get_logger_adapter("internal") + logger_adapter = extension.get_logger_adapter( + "internal", LoggerConfig().get_url() + ) # Test logger_adapter methods. 
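# --- Illustrative sketch, not part of the patch above ---
# The LoggerFactory added in dubbo/logger/logger_factory.py combines a
# lock-protected singleton in __new__ with an `initialize_check` decorator that
# lazily runs LoggerConfig().init() on first use. The minimal pattern behind
# both is double-checked locking; the class and method names below are
# hypothetical and used only for illustration.
import threading

class LazyInit:
    _lock = threading.Lock()

    def __init__(self):
        self._initialized = False

    def _ensure_initialized(self):
        if not self._initialized:              # fast path, no lock taken
            with self._lock:
                if not self._initialized:      # re-check while holding the lock
                    self._do_expensive_setup()
                    self._initialized = True

    def _do_expensive_setup(self):
        pass  # e.g. LoggerConfig().init() in the real factory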
logger = logger_adapter.get_logger("test") logger.debug("test debug") logger.info("test info") logger.warning("test warning") - logger.error("test error") \ No newline at end of file + logger.error("test error") diff --git a/tests/common/tets_url.py b/tests/common/tets_url.py index 40a3604..736f870 100644 --- a/tests/common/tets_url.py +++ b/tests/common/tets_url.py @@ -28,8 +28,8 @@ def test_str_to_url(self): self.assertEqual("www.facebook.com", url_0.host) self.assertEqual(None, url_0.port) self.assertEqual("friends", url_0.path) - self.assertEqual("value1", url_0.get_param("param1")) - self.assertEqual("value2", url_0.get_param("param2")) + self.assertEqual("value1", url_0.get_parameter("param1")) + self.assertEqual("value2", url_0.get_parameter("param2")) url_1 = URL.value_of("ftp://username:password@192.168.1.7:21/1/read.txt") self.assertEqual("ftp", url_1.protocol) @@ -37,7 +37,7 @@ def test_str_to_url(self): self.assertEqual("password", url_1.password) self.assertEqual("192.168.1.7", url_1.host) self.assertEqual(21, url_1.port) - self.assertEqual("192.168.1.7:21", url_1.address) + self.assertEqual("192.168.1.7:21", url_1.location) self.assertEqual("1/read.txt", url_1.path) url_2 = URL.value_of("file:///home/user1/router.js?type=script") @@ -52,8 +52,8 @@ def test_str_to_url(self): self.assertEqual("www.facebook.com", url_3.host) self.assertEqual(None, url_3.port) self.assertEqual("friends", url_3.path) - self.assertEqual("value1", url_3.get_param("param1")) - self.assertEqual("value2", url_3.get_param("param2")) + self.assertEqual("value1", url_3.get_parameter("param1")) + self.assertEqual("value2", url_3.get_parameter("param2")) def test_url_to_str(self): url_0 = URL( @@ -63,16 +63,20 @@ def test_url_to_str(self): username="username", password="password", path="path", - params={"type": "a"}, + parameters={"type": "a"}, ) self.assertEqual( - "tri://username:password@127.0.0.1:12/path?type=a", url_0.to_string() + "tri://username:password@127.0.0.1:12/path?type=a", url_0.build_string() ) url_1 = URL( - protocol="tri", host="127.0.0.1", port=12, path="path", params={"type": "a"} + protocol="tri", + host="127.0.0.1", + port=12, + path="path", + parameters={"type": "a"}, ) - self.assertEqual("tri://127.0.0.1:12/path?type=a", url_1.to_string()) + self.assertEqual("tri://127.0.0.1:12/path?type=a", url_1.build_string()) - url_2 = URL(protocol="tri", host="127.0.0.1", port=12, params={"type": "a"}) - self.assertEqual("tri://127.0.0.1:12/?type=a", url_2.to_string()) + url_2 = URL(protocol="tri", host="127.0.0.1", port=12, parameters={"type": "a"}) + self.assertEqual("tri://127.0.0.1:12/?type=a", url_2.build_string()) diff --git a/tests/logger/test_internal_logger.py b/tests/logger/test_internal_logger.py index 3f32a36..0150997 100644 --- a/tests/logger/test_internal_logger.py +++ b/tests/logger/test_internal_logger.py @@ -15,16 +15,17 @@ # limitations under the License. 
import unittest -from dubbo.logger import Level -from dubbo.logger.internal_logger import InternalLoggerAdapter +from dubbo.common.constants import LoggerLevel +from dubbo.config import LoggerConfig +from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter class TestInternalLogger(unittest.TestCase): def test_log(self): - logger_adapter = InternalLoggerAdapter() + logger_adapter = InternalLoggerAdapter(config=LoggerConfig().get_url()) logger = logger_adapter.get_logger("test") - logger.log(Level.INFO, "test log") + logger.log(LoggerLevel.INFO, "test log") logger.debug("test debug") logger.info("test info") logger.warning("test warning") @@ -37,13 +38,11 @@ def test_log(self): logger.exception("test exception") # test different default logger level - logger_adapter.level = Level.INFO + logger_adapter.level = LoggerLevel.INFO logger.debug("debug can't be logged") - logger_adapter.level = Level.WARNING + logger_adapter.level = LoggerLevel.WARNING logger.info("info can't be logged") - logger_adapter.level = Level.ERROR + logger_adapter.level = LoggerLevel.ERROR logger.warning("warning can't be logged") - - diff --git a/tests/logger/test_logger_factory.py b/tests/logger/test_logger_factory.py new file mode 100644 index 0000000..acb68e2 --- /dev/null +++ b/tests/logger/test_logger_factory.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
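# --- Illustrative sketch, not part of the patch above ---
# The is_enabled_for() method added to the Logger interface lets callers skip
# building expensive log messages when the level is filtered out. `dump_state`
# is a hypothetical helper used only for this example.
from dubbo.common.constants import LoggerLevel
from dubbo.config import LoggerConfig
from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter

def dump_state() -> str:
    return "..."  # stand-in for an expensive diagnostic report

adapter = InternalLoggerAdapter(config=LoggerConfig().get_url())
logger = adapter.get_logger("example")
if logger.is_enabled_for(LoggerLevel.DEBUG):
    logger.debug("state: %s", dump_state())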
+import unittest + +from dubbo.common.constants import LoggerConstants, LoggerLevel +from dubbo.config import LoggerConfig +from dubbo.logger import loggerFactory +from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter + + +class TestLoggerFactory(unittest.TestCase): + + # def test_without_config(self): + # # Test the case where config is not used + # logger = loggerFactory.get_logger("test_factory") + # logger.info("info log -> without_config ") + + def test_with_config(self): + # Test the case where config is used + config = LoggerConfig() + config.init() + logger = loggerFactory.get_logger("test_factory") + logger.info("info log -> with_config ") + + logger = loggerFactory.get_logger("test_factory1") + logger.info("info log -> with_config ") + + logger = loggerFactory.get_logger("test_factory2") + logger.info("info log -> with_config ") + + url = config.get_url() + url.add_parameter(LoggerConstants.LOGGER_FILE_ENABLED_KEY, True) + loggerFactory.logger_adapter = InternalLoggerAdapter(url) + logger = loggerFactory.get_logger("test_factory") + loggerFactory.level = LoggerLevel.DEBUG + logger.debug("debug log -> with_config") From 89ae4779febfd1941822aad6fc49cb5ff50592d2 Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 16 Jun 2024 22:47:14 +0800 Subject: [PATCH 18/32] perf: update something about logger --- dubbo/_dubbo.py | 2 +- dubbo/client/__init__.py | 15 +++ dubbo/common/constants/logger_constants.py | 10 ++ dubbo/common/node.py | 44 ++++++++ dubbo/config/__init__.py | 2 +- dubbo/config/logger_config.py | 18 ++- dubbo/logger/__init__.py | 2 +- dubbo/logger/internal/logger_adapter.py | 4 +- dubbo/logger/logger.py | 2 +- dubbo/logger/logger_factory.py | 124 +++++++++------------ tests/logger/test_logger_factory.py | 26 ++--- 11 files changed, 145 insertions(+), 104 deletions(-) create mode 100644 dubbo/client/__init__.py create mode 100644 dubbo/common/node.py diff --git a/dubbo/_dubbo.py b/dubbo/_dubbo.py index 5da4bd6..4f7a73b 100644 --- a/dubbo/_dubbo.py +++ b/dubbo/_dubbo.py @@ -16,6 +16,6 @@ class Dubbo: - """The entry point of dubbo-python framework.""" + """The entry point of dubbo-python framework.(singleton)""" pass diff --git a/dubbo/client/__init__.py b/dubbo/client/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/client/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/common/constants/logger_constants.py b/dubbo/common/constants/logger_constants.py index 14ee10b..0bb9e95 100644 --- a/dubbo/common/constants/logger_constants.py +++ b/dubbo/common/constants/logger_constants.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import enum +import os from functools import cache @@ -79,4 +80,13 @@ class LoggerConstants: """some logger default value""" LOGGER_DRIVER_VALUE = "internal" + LOGGER_LEVEL_VALUE = LoggerLevel.DEBUG + # console + LOGGER_CONSOLE_ENABLED_VALUE = True + # file + LOGGER_FILE_ENABLED_VALUE = False + LOGGER_FILE_DIR_VALUE = os.path.expanduser("~") LOGGER_FILE_NAME_VALUE = "dubbo.log" + LOGGER_FILE_MAX_BYTES_VALUE = 10 * 1024 * 1024 + LOGGER_FILE_INTERVAL_VALUE = 1 + LOGGER_FILE_BACKUP_COUNT_VALUE = 10 diff --git a/dubbo/common/node.py b/dubbo/common/node.py new file mode 100644 index 0000000..71d64df --- /dev/null +++ b/dubbo/common/node.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dubbo.common.url import URL + + +class Node: + """ + Node + """ + + def get_url(self) -> URL: + """ + Get the url of the node + Returns: + URL: URL of the node + """ + raise NotImplementedError("get_url() is not implemented.") + + def is_available(self) -> bool: + """ + Check if the node is available + Returns: + bool: True if the node is available, false otherwise + """ + raise NotImplementedError("is_available() is not implemented.") + + def destroy(self) -> None: + """ + Destroy the node + """ + raise NotImplementedError("destroy() is not implemented.") diff --git a/dubbo/config/__init__.py b/dubbo/config/__init__.py index 8adaf92..b6b51a2 100644 --- a/dubbo/config/__init__.py +++ b/dubbo/config/__init__.py @@ -13,4 +13,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .logger_config import LoggerConfig +from .logger_config import ConsoleLoggerConfig, FileLoggerConfig, LoggerConfig diff --git a/dubbo/config/logger_config.py b/dubbo/config/logger_config.py index ba569a0..4ba59b8 100644 --- a/dubbo/config/logger_config.py +++ b/dubbo/config/logger_config.py @@ -13,13 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
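# --- Illustrative sketch, not part of the patch above ---
# dubbo/common/node.py introduces the Node interface (get_url / is_available /
# destroy). A trivial implementation might look like the one below; StaticNode
# is a hypothetical name used only for illustration.
from dubbo.common.node import Node
from dubbo.common.url import URL

class StaticNode(Node):
    def __init__(self, url: URL):
        self._url = url
        self._destroyed = False

    def get_url(self) -> URL:
        return self._url

    def is_available(self) -> bool:
        return not self._destroyed

    def destroy(self) -> None:
        self._destroyed = True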
-import os from dataclasses import dataclass from typing import Dict, Optional from dubbo.common import extension -from dubbo.common.constants import (LoggerConstants, LoggerFileRotateType, - LoggerLevel) +from dubbo.common.constants import LoggerConstants, LoggerFileRotateType, LoggerLevel from dubbo.common.url import URL from dubbo.logger import loggerFactory @@ -29,7 +27,7 @@ class ConsoleLoggerConfig: """Console logger configuration""" # default is open console logger - console_enabled: bool = True + console_enabled: bool = LoggerConstants.LOGGER_CONSOLE_ENABLED_VALUE # default console formatter is None, use the global formatter console_formatter: Optional[str] = None @@ -48,21 +46,21 @@ class FileLoggerConfig: """File logger configuration""" # default is close file logger - file_enabled: bool = False + file_enabled: bool = LoggerConstants.LOGGER_FILE_ENABLED_VALUE # default file formatter is None, use the global formatter file_formatter: Optional[str] = None # default log file dir is user home dir - file_dir: str = os.path.expanduser("~") + file_dir: str = LoggerConstants.LOGGER_FILE_DIR_VALUE # default log file name is "dubbo.log" file_name: str = LoggerConstants.LOGGER_FILE_NAME_VALUE # default no rotate rotate: LoggerFileRotateType = LoggerFileRotateType.NONE # when rotate is SIZE, max_bytes is required, default 10M - max_bytes: int = 1024 * 1024 * 10 + max_bytes: int = LoggerConstants.LOGGER_FILE_MAX_BYTES_VALUE # when rotate is TIME, interval is required, unit is day, default 1 - interval: int = 1 + interval: int = LoggerConstants.LOGGER_FILE_INTERVAL_VALUE # when rotate is not NONE, backup_count is required, default 10 - backup_count: int = 10 + backup_count: int = LoggerConstants.LOGGER_FILE_BACKUP_COUNT_VALUE def check(self) -> None: if self.file_enabled: @@ -89,7 +87,7 @@ class LoggerConfig: def __init__( self, driver: str = LoggerConstants.LOGGER_DRIVER_VALUE, - level: LoggerLevel = LoggerLevel.DEBUG, + level: LoggerLevel = LoggerConstants.LOGGER_LEVEL_VALUE, formatter: Optional[str] = None, console: ConsoleLoggerConfig = ConsoleLoggerConfig(), file: FileLoggerConfig = FileLoggerConfig(), diff --git a/dubbo/logger/__init__.py b/dubbo/logger/__init__.py index f685669..5df0681 100644 --- a/dubbo/logger/__init__.py +++ b/dubbo/logger/__init__.py @@ -17,6 +17,6 @@ from .logger import Logger, LoggerAdapter from .logger_factory import LoggerFactory as _LoggerFactory -loggerFactory = _LoggerFactory() +loggerFactory = _LoggerFactory __all__ = ["Logger", "LoggerAdapter", "loggerFactory"] diff --git a/dubbo/logger/internal/logger_adapter.py b/dubbo/logger/internal/logger_adapter.py index 4215a25..2619a9c 100644 --- a/dubbo/logger/internal/logger_adapter.py +++ b/dubbo/logger/internal/logger_adapter.py @@ -40,7 +40,6 @@ class InternalLoggerAdapter(LoggerAdapter): def __init__(self, config: URL): super().__init__(config) - self._config = config # Set level level_name = config.parameters.get(LoggerConstants.LOGGER_LEVEL_KEY) self._level = ( @@ -63,8 +62,7 @@ def get_logger(self, name: str) -> Logger: """ logger_instance = logging.getLogger(name) # clean up handlers - for handler in logger_instance.handlers: - logger_instance.removeHandler(handler) + logger_instance.handlers.clear() parameters = self._config.parameters # Add console handler diff --git a/dubbo/logger/logger.py b/dubbo/logger/logger.py index 1cbb97f..9ce3271 100644 --- a/dubbo/logger/logger.py +++ b/dubbo/logger/logger.py @@ -136,7 +136,7 @@ def __init__(self, config: URL): Args: config(URL): config (URL): The config of the 
logger adapter. """ - pass + self._config = config def get_logger(self, name: str) -> Logger: """ diff --git a/dubbo/logger/logger_factory.py b/dubbo/logger/logger_factory.py index d3545cf..ca79e81 100644 --- a/dubbo/logger/logger_factory.py +++ b/dubbo/logger/logger_factory.py @@ -14,86 +14,66 @@ # See the License for the specific language governing permissions and # limitations under the License. import threading - -from dubbo.common.constants import LoggerLevel -from dubbo.logger.logger import Logger, LoggerAdapter - - -def initialize_check(func): - """ - Checks if the logger factory instance is initialized. - """ - - def wrapper(self, *args, **kwargs): - if not self._initialized: - with self._initialize_lock: - if not self._initialized: - # initialize LoggerFactory - from dubbo.config import LoggerConfig - - config = LoggerConfig() - config.init() - self._initialized = True - return func(self, *args, **kwargs) - - return wrapper +from typing import Dict + +from dubbo.common.constants import LoggerConstants, LoggerLevel +from dubbo.common.url import URL +from dubbo.logger import Logger, LoggerAdapter +from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter + +_default_config = URL( + protocol=LoggerConstants.LOGGER_DRIVER_VALUE, + host=LoggerConstants.LOGGER_LEVEL_VALUE.value, + port=None, + parameters={ + LoggerConstants.LOGGER_DRIVER_KEY: LoggerConstants.LOGGER_DRIVER_VALUE, + LoggerConstants.LOGGER_LEVEL_KEY: LoggerConstants.LOGGER_LEVEL_VALUE.value, + LoggerConstants.LOGGER_CONSOLE_ENABLED_KEY: str( + LoggerConstants.LOGGER_CONSOLE_ENABLED_VALUE + ), + LoggerConstants.LOGGER_FILE_ENABLED_KEY: str( + LoggerConstants.LOGGER_FILE_ENABLED_VALUE + ), + }, +) class LoggerFactory: """ - Factory class to create loggers. (single object) + Factory class to create loggers. """ - _instance = None - _instance_lock = threading.Lock() - - def __new__(cls, *args, **kwargs): - if not cls._instance: - with cls._instance_lock: - if not cls._instance: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self): - self._logger_adapter = None - # A dictionary to store all the loggers. - self._loggers = {} - # A lock to protect the loggers. - self._loggers_lock = threading.Lock() - # Initialization flag - self._initialized = False - self._initialize_lock = threading.Lock() - - @property - @initialize_check - def logger_adapter(self) -> LoggerAdapter: - return self._logger_adapter + # logger adapter + _logger_adapter = InternalLoggerAdapter(_default_config) + # A dictionary to store all the loggers. + _loggers: Dict[str, Logger] = {} + # A lock to protect the loggers. + _loggers_lock = threading.Lock() - @logger_adapter.setter - def logger_adapter(self, logger_adapter) -> None: + @classmethod + def set_logger_adapter(cls, logger_adapter) -> None: """ Set logger config """ - self._logger_adapter = logger_adapter - with self._loggers_lock: + cls._logger_adapter = logger_adapter + with cls._loggers_lock: # update all loggers - self._loggers = { - name: self._logger_adapter.get_logger(name) for name in self._loggers + cls._loggers = { + name: cls._logger_adapter.get_logger(name) for name in cls._loggers } - self._initialized = True - @initialize_check - def get_logger_adapter(self) -> LoggerAdapter: + @classmethod + def get_logger_adapter(cls) -> LoggerAdapter: """ Get the logger adapter. Returns: LoggerAdapter: The current logger adapter. 
""" - return self._logger_adapter + return cls._logger_adapter - @initialize_check - def get_logger(self, name: str) -> Logger: + @classmethod + def get_logger(cls, name: str) -> Logger: """ Get the logger by name. @@ -103,32 +83,30 @@ def get_logger(self, name: str) -> Logger: Returns: Logger: An instance of the requested logger. """ - logger = self._loggers.get(name) + logger = cls._loggers.get(name) if logger is None: - with self._loggers_lock: - if name not in self._loggers: - self._loggers[name] = self._logger_adapter.get_logger(name) - logger = self._loggers[name] + with cls._loggers_lock: + if name not in cls._loggers: + cls._loggers[name] = cls._logger_adapter.get_logger(name) + logger = cls._loggers[name] return logger - @property - @initialize_check - def level(self) -> LoggerLevel: + @classmethod + def get_level(cls) -> LoggerLevel: """ Get the current logging level. Returns: LoggerLevel: The current logging level. """ - return self._logger_adapter.level + return cls._logger_adapter.level - @level.setter - @initialize_check - def level(self, level: LoggerLevel) -> None: + @classmethod + def set_level(cls, level: LoggerLevel) -> None: """ Set the logging level. Args: level (LoggerLevel): The logging level to set. """ - self._logger_adapter.level = level + cls._logger_adapter.level = level diff --git a/tests/logger/test_logger_factory.py b/tests/logger/test_logger_factory.py index acb68e2..03446b1 100644 --- a/tests/logger/test_logger_factory.py +++ b/tests/logger/test_logger_factory.py @@ -23,10 +23,10 @@ class TestLoggerFactory(unittest.TestCase): - # def test_without_config(self): - # # Test the case where config is not used - # logger = loggerFactory.get_logger("test_factory") - # logger.info("info log -> without_config ") + def test_without_config(self): + # Test the case where config is not used + logger = loggerFactory.get_logger("test_factory") + logger.info("info log -> without_config ") def test_with_config(self): # Test the case where config is used @@ -35,15 +35,13 @@ def test_with_config(self): logger = loggerFactory.get_logger("test_factory") logger.info("info log -> with_config ") - logger = loggerFactory.get_logger("test_factory1") - logger.info("info log -> with_config ") - - logger = loggerFactory.get_logger("test_factory2") - logger.info("info log -> with_config ") - url = config.get_url() url.add_parameter(LoggerConstants.LOGGER_FILE_ENABLED_KEY, True) - loggerFactory.logger_adapter = InternalLoggerAdapter(url) - logger = loggerFactory.get_logger("test_factory") - loggerFactory.level = LoggerLevel.DEBUG - logger.debug("debug log -> with_config") + loggerFactory.set_logger_adapter(InternalLoggerAdapter(url)) + loggerFactory.set_level(LoggerLevel.DEBUG) + logger.debug("debug log -> with_config -> open file") + + url.add_parameter(LoggerConstants.LOGGER_CONSOLE_ENABLED_KEY, False) + loggerFactory.set_logger_adapter(InternalLoggerAdapter(url)) + loggerFactory.set_level(LoggerLevel.DEBUG) + logger.debug("debug log -> with_config -> lose console") From 1e739774edf12a224cf182561fa6d02dc980fcd6 Mon Sep 17 00:00:00 2001 From: zaki Date: Mon, 17 Jun 2024 11:59:16 +0800 Subject: [PATCH 19/32] style: Becoming more regulated --- dubbo/client/client.py | 23 +++ dubbo/common/constants/__init__.py | 2 - .../{logger_constants.py => logger.py} | 67 ++++---- dubbo/common/url.py | 78 +++++----- dubbo/config/logger_config.py | 143 +++++++++++------- dubbo/logger/internal/logger.py | 24 +-- dubbo/logger/internal/logger_adapter.py | 77 +++++----- dubbo/logger/logger.py | 22 +-- 
dubbo/logger/logger_factory.py | 47 +++--- tests/common/tets_url.py | 4 +- tests/logger/test_internal_logger.py | 10 +- tests/logger/test_logger_factory.py | 13 +- 12 files changed, 291 insertions(+), 219 deletions(-) create mode 100644 dubbo/client/client.py rename dubbo/common/constants/{logger_constants.py => logger.py} (51%) diff --git a/dubbo/client/client.py b/dubbo/client/client.py new file mode 100644 index 0000000..e4eaefd --- /dev/null +++ b/dubbo/client/client.py @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dubbo.logger import loggerFactory + +logger = loggerFactory.get_logger(__name__) + + +class Client: + + pass diff --git a/dubbo/common/constants/__init__.py b/dubbo/common/constants/__init__.py index 44dc90e..bcba37a 100644 --- a/dubbo/common/constants/__init__.py +++ b/dubbo/common/constants/__init__.py @@ -13,5 +13,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .logger_constants import (LoggerConstants, LoggerFileRotateType, - LoggerLevel) diff --git a/dubbo/common/constants/logger_constants.py b/dubbo/common/constants/logger.py similarity index 51% rename from dubbo/common/constants/logger_constants.py rename to dubbo/common/constants/logger.py index 0bb9e95..b68cab8 100644 --- a/dubbo/common/constants/logger_constants.py +++ b/dubbo/common/constants/logger.py @@ -19,7 +19,7 @@ @enum.unique -class LoggerLevel(enum.Enum): +class Level(enum.Enum): """ The logging level enum. """ @@ -33,7 +33,7 @@ class LoggerLevel(enum.Enum): @classmethod @cache - def get_level(cls, level_value: str) -> "LoggerLevel": + def get_level(cls, level_value: str) -> "Level": level_value = level_value.upper() for level in cls: if level_value == level.value: @@ -42,7 +42,7 @@ def get_level(cls, level_value: str) -> "LoggerLevel": @enum.unique -class LoggerFileRotateType(enum.Enum): +class FileRotateType(enum.Enum): """ The file rotating type enum. 
""" @@ -55,38 +55,35 @@ class LoggerFileRotateType(enum.Enum): TIME = "TIME" -class LoggerConstants: - """logger configuration constants.""" +"""logger config keys""" +# global config +LEVEL_KEY = "logger.level" +DRIVER_KEY = "logger.driver" +FORMAT_KEY = "logger.format" - """logger config keys""" - # global config - LOGGER_LEVEL_KEY = "logger.level" - LOGGER_DRIVER_KEY = "logger.driver" - LOGGER_FORMAT_KEY = "logger.format" +# console config +CONSOLE_ENABLED_KEY = "logger.console.enable" +CONSOLE_FORMAT_KEY = "logger.console.format" - # console config - LOGGER_CONSOLE_ENABLED_KEY = "logger.console.enable" - LOGGER_CONSOLE_FORMAT_KEY = "logger.console.format" +# file logger +FILE_ENABLED_KEY = "logger.file.enable" +FILE_FORMAT_KEY = "logger.file.format" +FILE_DIR_KEY = "logger.file.dir" +FILE_NAME_KEY = "logger.file.name" +FILE_ROTATE_KEY = "logger.file.rotate" +FILE_MAX_BYTES_KEY = "logger.file.maxbytes" +FILE_INTERVAL_KEY = "logger.file.interval" +FILE_BACKUP_COUNT_KEY = "logger.file.backupcount" - # file logger - LOGGER_FILE_ENABLED_KEY = "logger.file.enable" - LOGGER_FILE_FORMAT_KEY = "logger.file.format" - LOGGER_FILE_DIR_KEY = "logger.file.dir" - LOGGER_FILE_NAME_KEY = "logger.file.name" - LOGGER_FILE_ROTATE_KEY = "logger.file.rotate" - LOGGER_FILE_MAX_BYTES_KEY = "logger.file.maxbytes" - LOGGER_FILE_INTERVAL_KEY = "logger.file.interval" - LOGGER_FILE_BACKUP_COUNT_KEY = "logger.file.backupcount" - - """some logger default value""" - LOGGER_DRIVER_VALUE = "internal" - LOGGER_LEVEL_VALUE = LoggerLevel.DEBUG - # console - LOGGER_CONSOLE_ENABLED_VALUE = True - # file - LOGGER_FILE_ENABLED_VALUE = False - LOGGER_FILE_DIR_VALUE = os.path.expanduser("~") - LOGGER_FILE_NAME_VALUE = "dubbo.log" - LOGGER_FILE_MAX_BYTES_VALUE = 10 * 1024 * 1024 - LOGGER_FILE_INTERVAL_VALUE = 1 - LOGGER_FILE_BACKUP_COUNT_VALUE = 10 +"""some logger default value""" +DEFAULT_DRIVER_VALUE = "logging" +DEFAULT_LEVEL_VALUE = Level.DEBUG +# console +DEFAULT_CONSOLE_ENABLED_VALUE = True +# file +DEFAULT_FILE_ENABLED_VALUE = False +DEFAULT_FILE_DIR_VALUE = os.path.expanduser("~") +DEFAULT_FILE_NAME_VALUE = "dubbo.log" +DEFAULT_FILE_MAX_BYTES_VALUE = 10 * 1024 * 1024 +DEFAULT_FILE_INTERVAL_VALUE = 1 +DEFAULT_FILE_BACKUP_COUNT_VALUE = 10 diff --git a/dubbo/common/url.py b/dubbo/common/url.py index bb78f49..64dcf4c 100644 --- a/dubbo/common/url.py +++ b/dubbo/common/url.py @@ -19,7 +19,15 @@ class URL: """ - URL - Uniform Resource Locator + URL - Uniform Resource Locator. + Attributes: + _protocol (str): The protocol of the URL. + _host (str): The host of the URL. + _port (int): The port number of the URL. + _username (str): The username for URL authentication. + _password (str): The password for URL authentication. + _path (str): The path of the URL. + _parameters (Dict[str, str]): The query parameters of the URL. url example: - http://www.facebook.com/friends?param1=value1¶m2=value2 @@ -28,33 +36,29 @@ class URL: - registry://192.168.1.7:9090/org.apache.dubbo.service1?param1=value1¶m2=value2 """ + _protocol: str + _username: str + _password: str + _host: str + _port: int + _path: str + _parameters: Dict[str, str] + def __init__( self, protocol: str, - host: Optional[str], - port: Optional[int], - username: Optional[str] = None, - password: Optional[str] = None, - path: Optional[str] = None, + host: str, + port: int = 0, + username: str = "", + password: str = "", + path: str = "", parameters: Optional[Dict[str, str]] = None, ): - """ - Initializes the URL with the given components. 
- - Args: - protocol (str): The protocol of the URL. - host (Optional[str]): The host of the URL. - port (Optional[int]): The port number of the URL. - username (Optional[str]): The username for URL authentication. - password (Optional[str]): The password for URL authentication. - path (Optional[str]): The path of the URL. - parameters (Optional[Dict[str, str]]): The query parameters of the URL. - """ self._protocol = protocol self._host = host self._port = port # location -> host:port - self._location = f"{host}:{port}" if host and port else host or None + self._location = f"{host}:{port}" if port > 0 else host self._username = username self._password = password self._path = path @@ -81,22 +85,22 @@ def protocol(self, protocol: str) -> None: self._protocol = protocol @property - def location(self) -> Optional[str]: + def location(self) -> str: """ Gets the location (host:port) of the URL. Returns: - Optional[str]: The location of the URL. + str: The location of the URL. """ return self._location @property - def host(self) -> Optional[str]: + def host(self) -> str: """ Gets the host of the URL. Returns: - Optional[str]: The host of the URL. + str: The host of the URL. """ return self._host @@ -112,12 +116,12 @@ def host(self, host: str) -> None: self._location = f"{host}:{self.port}" if self.port else host @property - def port(self) -> Optional[int]: + def port(self) -> int: """ Gets the port of the URL. Returns: - Optional[int]: The port of the URL. + int: The port of the URL. """ return self._port @@ -133,12 +137,12 @@ def port(self, port: int) -> None: self._location = f"{self.host}:{port}" if port else self.host @property - def username(self) -> Optional[str]: + def username(self) -> str: """ Gets the username for URL authentication. Returns: - Optional[str]: The username for URL authentication. + str: The username for URL authentication. """ return self._username @@ -153,12 +157,12 @@ def username(self, username: str) -> None: self._username = username @property - def password(self) -> Optional[str]: + def password(self) -> str: """ Gets the password for URL authentication. Returns: - Optional[str]: The password for URL authentication. + [str]: The password for URL authentication. """ return self._password @@ -173,12 +177,12 @@ def password(self, password: str) -> None: self._password = password @property - def path(self) -> Optional[str]: + def path(self) -> str: """ Gets the path of the URL. Returns: - Optional[str]: The path of the URL. + str: The path of the URL. """ return self._path @@ -198,7 +202,7 @@ def parameters(self) -> Dict[str, str]: Gets the query parameters of the URL. Returns: - Optional[Dict[str, str]]: The query parameters of the URL. + Dict[str, str]: The query parameters of the URL. """ return self._parameters @@ -217,7 +221,7 @@ def get_parameter(self, key: str) -> Optional[str]: Gets a query parameter from the URL. Args: - key (str): The parameter name. + key (Optional[str]): The parameter name. Returns: str or None: The parameter value. If the parameter does not exist, returns None. 
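# --- Illustrative usage sketch, not part of the patch above ---
# It exercises the URL API shown in this diff (value_of, location,
# get_parameter, build_string); the address and parameter values are made up.
from dubbo.common.url import URL

url = URL.value_of("tri://user:pwd@127.0.0.1:20880/org.example.Service?serialization=fastjson")
assert url.protocol == "tri"
assert url.location == "127.0.0.1:20880"
assert url.get_parameter("serialization") == "fastjson"
assert url.get_parameter("missing") is None

# Building a URL from parts, mirroring tests/common/tets_url.py:
built = URL(protocol="tri", host="127.0.0.1", port=12, path="path", parameters={"type": "a"})
assert built.build_string() == "tri://127.0.0.1:12/path?type=a"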
@@ -300,10 +304,10 @@ def value_of(cls, url: str, encoded: bool = False) -> "URL": parsed_url = parse.urlparse(url) protocol = parsed_url.scheme - host = parsed_url.hostname - port = parsed_url.port - username = parsed_url.username - password = parsed_url.password + host = parsed_url.hostname or "" + port = parsed_url.port or 0 + username = parsed_url.username or "" + password = parsed_url.password or "" parameters = {k: v[0] for k, v in parse.parse_qs(parsed_url.query).items()} path = parsed_url.path.lstrip("/") diff --git a/dubbo/config/logger_config.py b/dubbo/config/logger_config.py index 4ba59b8..43035b8 100644 --- a/dubbo/config/logger_config.py +++ b/dubbo/config/logger_config.py @@ -17,108 +17,135 @@ from typing import Dict, Optional from dubbo.common import extension -from dubbo.common.constants import LoggerConstants, LoggerFileRotateType, LoggerLevel +from dubbo.common.constants import logger as logger_constants +from dubbo.common.constants.logger import FileRotateType, Level from dubbo.common.url import URL from dubbo.logger import loggerFactory @dataclass class ConsoleLoggerConfig: - """Console logger configuration""" + """ + Console logger configuration. + Attributes: + console_format(Optional[str]): console format, if null, use global format. + """ - # default is open console logger - console_enabled: bool = LoggerConstants.LOGGER_CONSOLE_ENABLED_VALUE - # default console formatter is None, use the global formatter - console_formatter: Optional[str] = None + console_format: Optional[str] = None def check(self): pass def dict(self) -> Dict[str, str]: return { - LoggerConstants.LOGGER_CONSOLE_ENABLED_KEY: str(self.console_enabled), - LoggerConstants.LOGGER_CONSOLE_FORMAT_KEY: self.console_formatter or "", + logger_constants.CONSOLE_FORMAT_KEY: self.console_format or "", } @dataclass class FileLoggerConfig: - """File logger configuration""" - - # default is close file logger - file_enabled: bool = LoggerConstants.LOGGER_FILE_ENABLED_VALUE - # default file formatter is None, use the global formatter + """ + File logger configuration. + Attributes: + rotate(FileRotateType): File rotate type. Optional: NONE,SIZE,TIME. Default: NONE. + file_formatter(Optional[str]): file format, if null, use global format. + file_dir(str): file directory. Default: user home dir + file_name(str): file name. Default: dubbo.log + backup_count(int): backup count. Default: 10 (when rotate is not NONE, backup_count is required) + max_bytes(int): maximum file size. Default: 1024.(when rotate is SIZE, max_bytes is required) + interval(int): interval time in seconds. 
Default: 1.(when rotate is TIME, interval is required, unit is day) + + """ + + rotate: FileRotateType = FileRotateType.NONE file_formatter: Optional[str] = None - # default log file dir is user home dir - file_dir: str = LoggerConstants.LOGGER_FILE_DIR_VALUE - # default log file name is "dubbo.log" - file_name: str = LoggerConstants.LOGGER_FILE_NAME_VALUE - # default no rotate - rotate: LoggerFileRotateType = LoggerFileRotateType.NONE - # when rotate is SIZE, max_bytes is required, default 10M - max_bytes: int = LoggerConstants.LOGGER_FILE_MAX_BYTES_VALUE - # when rotate is TIME, interval is required, unit is day, default 1 - interval: int = LoggerConstants.LOGGER_FILE_INTERVAL_VALUE - # when rotate is not NONE, backup_count is required, default 10 - backup_count: int = LoggerConstants.LOGGER_FILE_BACKUP_COUNT_VALUE + file_dir: str = logger_constants.DEFAULT_FILE_DIR_VALUE + file_name: str = logger_constants.DEFAULT_FILE_NAME_VALUE + backup_count: int = logger_constants.DEFAULT_FILE_BACKUP_COUNT_VALUE + max_bytes: int = logger_constants.DEFAULT_FILE_MAX_BYTES_VALUE + interval: int = logger_constants.DEFAULT_FILE_INTERVAL_VALUE def check(self) -> None: - if self.file_enabled: - if self.rotate == LoggerFileRotateType.SIZE and self.max_bytes < 0: - raise ValueError("Max bytes can't be less than 0") - elif self.rotate == LoggerFileRotateType.TIME and self.interval < 1: - raise ValueError("Interval can't be less than 1") + if self.rotate == FileRotateType.SIZE and self.max_bytes < 0: + raise ValueError("Max bytes can't be less than 0") + elif self.rotate == FileRotateType.TIME and self.interval < 1: + raise ValueError("Interval can't be less than 1") def dict(self) -> Dict[str, str]: return { - LoggerConstants.LOGGER_FILE_ENABLED_KEY: str(self.file_enabled), - LoggerConstants.LOGGER_FILE_FORMAT_KEY: self.file_formatter or "", - LoggerConstants.LOGGER_FILE_DIR_KEY: self.file_dir, - LoggerConstants.LOGGER_FILE_NAME_KEY: self.file_name, - LoggerConstants.LOGGER_FILE_ROTATE_KEY: self.rotate.value, - LoggerConstants.LOGGER_FILE_MAX_BYTES_KEY: str(self.max_bytes), - LoggerConstants.LOGGER_FILE_INTERVAL_KEY: str(self.interval), - LoggerConstants.LOGGER_FILE_BACKUP_COUNT_KEY: str(self.backup_count), + logger_constants.FILE_FORMAT_KEY: self.file_formatter or "", + logger_constants.FILE_DIR_KEY: self.file_dir, + logger_constants.FILE_NAME_KEY: self.file_name, + logger_constants.FILE_ROTATE_KEY: self.rotate.value, + logger_constants.FILE_MAX_BYTES_KEY: str(self.max_bytes), + logger_constants.FILE_INTERVAL_KEY: str(self.interval), + logger_constants.FILE_BACKUP_COUNT_KEY: str(self.backup_count), } class LoggerConfig: + """ + Logger configuration. + + Attributes: + _driver(str): logger driver type. + _level(Level): logger level. + _formatter(Optional[str]): logger formatter. + _console_enabled(bool): logger console enabled. + _console_config(ConsoleLoggerConfig): logger console config. + _file_enabled(bool): logger file enabled. + _file_config(FileLoggerConfig): logger file config. 
+ """ + + # global + _driver: str + _level: Level + _formatter: Optional[str] + # console + _console_enabled: bool + _console_config: ConsoleLoggerConfig + # file + _file_enabled: bool + _file_config: FileLoggerConfig def __init__( self, - driver: str = LoggerConstants.LOGGER_DRIVER_VALUE, - level: LoggerLevel = LoggerConstants.LOGGER_LEVEL_VALUE, + driver, + level=logger_constants.DEFAULT_LEVEL_VALUE, formatter: Optional[str] = None, - console: ConsoleLoggerConfig = ConsoleLoggerConfig(), - file: FileLoggerConfig = FileLoggerConfig(), + console_enabled: bool = logger_constants.DEFAULT_CONSOLE_ENABLED_VALUE, + console_config: ConsoleLoggerConfig = ConsoleLoggerConfig(), + file_enabled: bool = logger_constants.DEFAULT_FILE_ENABLED_VALUE, + file_config: FileLoggerConfig = FileLoggerConfig(), ): # set global config self._driver = driver self._level = level self._formatter = formatter # set console config - self._console = console - self._console.check() + self._console_enabled = console_enabled + self._console_config = console_config + if console_enabled: + self._console_config.check() # set file comfig - self._file = file - self._file.check() + self._file_enabled = file_enabled + self._file_config = file_config + if file_enabled: + self._file_config.check() def get_url(self) -> URL: # get LoggerConfig parameters - parameters: Dict[str, str] = { - **self._console.dict(), - **self._file.dict(), - LoggerConstants.LOGGER_DRIVER_KEY: self._driver, - LoggerConstants.LOGGER_LEVEL_KEY: self._level.value, - LoggerConstants.LOGGER_FORMAT_KEY: self._formatter or "", + parameters = { + logger_constants.DRIVER_KEY: self._driver, + logger_constants.LEVEL_KEY: self._level.value, + logger_constants.FORMAT_KEY: self._formatter or "", + logger_constants.CONSOLE_ENABLED_KEY: str(self._console_enabled), + logger_constants.FILE_ENABLED_KEY: str(self._file_enabled), + **self._console_config.dict(), + **self._file_config.dict(), } - return URL( - protocol=self._driver, - host=self._level.value, - port=None, - parameters=parameters, - ) + return URL(protocol=self._driver, host=self._level.value, parameters=parameters) def init(self): # get logger_adapter and initialize loggerFactory diff --git a/dubbo/logger/internal/logger.py b/dubbo/logger/internal/logger.py index 5e87761..6e84a35 100644 --- a/dubbo/logger/internal/logger.py +++ b/dubbo/logger/internal/logger.py @@ -17,25 +17,29 @@ import logging from typing import Dict -from dubbo.common.constants import LoggerLevel +from dubbo.common.constants.logger import Level from dubbo.logger import Logger # The mapping from the logging level to the internal logging level. -_level_map: Dict[LoggerLevel, int] = { - LoggerLevel.DEBUG: logging.DEBUG, - LoggerLevel.INFO: logging.INFO, - LoggerLevel.WARNING: logging.WARNING, - LoggerLevel.ERROR: logging.ERROR, - LoggerLevel.CRITICAL: logging.CRITICAL, - LoggerLevel.FATAL: logging.FATAL, +_level_map: Dict[Level, int] = { + Level.DEBUG: logging.DEBUG, + Level.INFO: logging.INFO, + Level.WARNING: logging.WARNING, + Level.ERROR: logging.ERROR, + Level.CRITICAL: logging.CRITICAL, + Level.FATAL: logging.FATAL, } class InternalLogger(Logger): """ The internal logger implementation. 
+ Attributes: + _logger (logging.Logger): The real working logger object """ + _logger: logging.Logger + def __init__(self, internal_logger: logging.Logger): self._logger = internal_logger @@ -44,7 +48,7 @@ def _log(self, level: int, msg: str, *args, **kwargs) -> None: kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 2 self._logger.log(level, msg, *args, **kwargs) - def log(self, level: LoggerLevel, msg: str, *args, **kwargs) -> None: + def log(self, level: Level, msg: str, *args, **kwargs) -> None: self._log(_level_map[level], msg, *args, **kwargs) def debug(self, msg: str, *args, **kwargs) -> None: @@ -70,6 +74,6 @@ def exception(self, msg: str, *args, **kwargs) -> None: kwargs["exc_info"] = True self.error(msg, *args, **kwargs) - def is_enabled_for(self, level: LoggerLevel) -> bool: + def is_enabled_for(self, level: Level) -> bool: logging_level = _level_map.get(level) return self._logger.isEnabledFor(logging_level) if logging_level else False diff --git a/dubbo/logger/internal/logger_adapter.py b/dubbo/logger/internal/logger_adapter.py index 2619a9c..b4ba560 100644 --- a/dubbo/logger/internal/logger_adapter.py +++ b/dubbo/logger/internal/logger_adapter.py @@ -20,36 +20,38 @@ from logging import handlers from dubbo.common import extension -from dubbo.common.constants import (LoggerConstants, LoggerFileRotateType, - LoggerLevel) +from dubbo.common.constants import logger as logger_constants +from dubbo.common.constants.logger import FileRotateType, Level from dubbo.common.url import URL from dubbo.logger import Logger, LoggerAdapter from dubbo.logger.internal.logger import InternalLogger """This module provides the internal logger implementation. -> logging module""" +_default_format = "%(asctime)s | %(levelname)s | %(module)s:%(funcName)s:%(lineno)d - [Dubbo] %(message)s" -@extension.register_logger_adapter("internal") + +@extension.register_logger_adapter("logging") class InternalLoggerAdapter(LoggerAdapter): """ - Internal logger adapter. - Responsible for internal logger creation, encapsulated the logging.getLogger() method + Internal logger adapter.Responsible for internal logger creation, encapsulated the logging.getLogger() method + Attributes: + _level(Level): logging level. + _format(str): default logging format string. 
""" - _default_format = "%(asctime)s | %(levelname)s | %(module)s:%(funcName)s:%(lineno)d - [Dubbo] %(message)s" + _level: Level + _format: str def __init__(self, config: URL): super().__init__(config) # Set level - level_name = config.parameters.get(LoggerConstants.LOGGER_LEVEL_KEY) - self._level = ( - LoggerLevel.get_level(level_name) if level_name else LoggerLevel.DEBUG - ) + level_name = config.parameters.get(logger_constants.LEVEL_KEY) + self._level = Level.get_level(level_name) if level_name else Level.DEBUG self._update_level() # Set format - self._format_str = ( - config.parameters.get(LoggerConstants.LOGGER_FORMAT_KEY) - or self._default_format + self._format = ( + config.parameters.get(logger_constants.FORMAT_KEY) or _default_format ) def get_logger(self, name: str) -> Logger: @@ -66,13 +68,17 @@ def get_logger(self, name: str) -> Logger: parameters = self._config.parameters # Add console handler - if parameters.get(LoggerConstants.LOGGER_CONSOLE_ENABLED_KEY) == str(True): + if parameters.get(logger_constants.CONSOLE_ENABLED_KEY) == str(True): logger_instance.addHandler(self._get_console_handler()) # Add file handler - if parameters.get(LoggerConstants.LOGGER_FILE_ENABLED_KEY) == str(True): + if parameters.get(logger_constants.FILE_ENABLED_KEY) == str(True): logger_instance.addHandler(self._get_file_handler()) + if not logger_instance.handlers: + # It's intended to be used to avoid the "No handlers could be found for logger XXX" one-off warning. + logger_instance.addHandler(logging.NullHandler()) + return InternalLogger(logger_instance) @cache @@ -84,11 +90,10 @@ def _get_console_handler(self) -> logging.StreamHandler: """ parameters = self._config.parameters console_handler = logging.StreamHandler() - console_format_str = ( - parameters.get(LoggerConstants.LOGGER_CONSOLE_FORMAT_KEY) - or self._format_str + console_format = ( + parameters.get(logger_constants.CONSOLE_FORMAT_KEY) or self._format ) - console_formatter = logging.Formatter(console_format_str) + console_formatter = logging.Formatter(console_format) console_handler.setFormatter(console_formatter) return console_handler @@ -102,59 +107,59 @@ def _get_file_handler(self) -> logging.Handler: """ parameters = self._config.parameters # Get file path - file_dir = parameters[LoggerConstants.LOGGER_FILE_DIR_KEY] + file_dir = parameters[logger_constants.FILE_DIR_KEY] file_name = ( - parameters[LoggerConstants.LOGGER_FILE_NAME_KEY] - or LoggerConstants.LOGGER_FILE_NAME_VALUE + parameters[logger_constants.FILE_NAME_KEY] + or logger_constants.DEFAULT_FILE_NAME_VALUE ) file_path = os.path.join(file_dir, file_name) # Get backup count backup_count = int( - parameters.get(LoggerConstants.LOGGER_FILE_BACKUP_COUNT_KEY) or 0 + parameters.get(logger_constants.FILE_BACKUP_COUNT_KEY) + or logger_constants.DEFAULT_FILE_BACKUP_COUNT_VALUE ) # Get rotate type - rotate_type = parameters.get(LoggerConstants.LOGGER_FILE_ROTATE_KEY) + rotate_type = parameters.get(logger_constants.FILE_ROTATE_KEY) # Set file Handler file_handler: logging.Handler - if rotate_type == LoggerFileRotateType.SIZE.value: + if rotate_type == FileRotateType.SIZE.value: # Set RotatingFileHandler - max_bytes = int(parameters[LoggerConstants.LOGGER_FILE_MAX_BYTES_KEY]) + max_bytes = int(parameters[logger_constants.FILE_MAX_BYTES_KEY]) file_handler = handlers.RotatingFileHandler( file_path, maxBytes=max_bytes, backupCount=backup_count ) - elif rotate_type == LoggerFileRotateType.TIME.value: + elif rotate_type == FileRotateType.TIME.value: # Set TimedRotatingFileHandler - 
interval = int(parameters[LoggerConstants.LOGGER_FILE_INTERVAL_KEY]) + interval = int(parameters[logger_constants.FILE_INTERVAL_KEY]) file_handler = handlers.TimedRotatingFileHandler( file_path, interval=interval, backupCount=backup_count ) else: # Set FileHandler file_handler = logging.FileHandler(file_path) + # Add file_handler - file_format_str = ( - parameters.get(LoggerConstants.LOGGER_FILE_FORMAT_KEY) or self._format_str - ) - file_formatter = logging.Formatter(file_format_str) + file_format = parameters.get(logger_constants.FILE_FORMAT_KEY) or self._format + file_formatter = logging.Formatter(file_format) file_handler.setFormatter(file_formatter) return file_handler @property - def level(self) -> LoggerLevel: + def level(self) -> Level: """ Get the logging level. Returns: - LoggerLevel: The logging level. + Level: The logging level. """ return self._level @level.setter - def level(self, level: LoggerLevel) -> None: + def level(self, level: Level) -> None: """ Set the logging level. Args: - level (LoggerLevel): The logging level. + level (Level): The logging level. """ if level == self._level or level is None: return diff --git a/dubbo/logger/logger.py b/dubbo/logger/logger.py index 9ce3271..a0c7460 100644 --- a/dubbo/logger/logger.py +++ b/dubbo/logger/logger.py @@ -15,7 +15,7 @@ # limitations under the License. from typing import Any -from dubbo.common.constants import LoggerLevel +from dubbo.common.constants.logger import Level from dubbo.common.url import URL @@ -24,12 +24,12 @@ class Logger: Logger Interface, which is used to log messages. """ - def log(self, level: LoggerLevel, msg: str, *args: Any, **kwargs: Any) -> None: + def log(self, level: Level, msg: str, *args: Any, **kwargs: Any) -> None: """ Log a message at the specified logging level. Args: - level (LoggerLevel): The logging level. + level (Level): The logging level. msg (str): The log message. *args (Any): Additional positional arguments. **kwargs (Any): Additional keyword arguments. @@ -113,11 +113,11 @@ def exception(self, msg: str, *args, **kwargs) -> None: """ raise NotImplementedError("exception() is not implemented.") - def is_enabled_for(self, level: LoggerLevel) -> bool: + def is_enabled_for(self, level: Level) -> bool: """ Is this logger enabled for level 'level'? Args: - level (LoggerLevel): The logging level. + level (Level): The logging level. Return: bool: Whether the logging level is enabled. """ @@ -127,8 +127,12 @@ def is_enabled_for(self, level: LoggerLevel) -> bool: class LoggerAdapter: """ Logger Adapter Interface, which is used to support different logging libraries. + Attributes: + _config(URL): logger adapter configuration. """ + _config: URL + def __init__(self, config: URL): """ Initialize the logger adapter. @@ -151,21 +155,21 @@ def get_logger(self, name: str) -> Logger: raise NotImplementedError("get_logger() is not implemented.") @property - def level(self) -> LoggerLevel: + def level(self) -> Level: """ Get the current logging level. Returns: - LoggerLevel: The current logging level. + Level: The current logging level. """ raise NotImplementedError("get_level() is not implemented.") @level.setter - def level(self, level: LoggerLevel) -> None: + def level(self, level: Level) -> None: """ Set the logging level. Args: - level (LoggerLevel): The logging level to set. + level (Level): The logging level to set. 
""" raise NotImplementedError("set_level() is not implemented.") diff --git a/dubbo/logger/logger_factory.py b/dubbo/logger/logger_factory.py index ca79e81..4b594ab 100644 --- a/dubbo/logger/logger_factory.py +++ b/dubbo/logger/logger_factory.py @@ -16,23 +16,24 @@ import threading from typing import Dict -from dubbo.common.constants import LoggerConstants, LoggerLevel +from dubbo.common.constants import logger as logger_constants +from dubbo.common.constants.logger import Level from dubbo.common.url import URL from dubbo.logger import Logger, LoggerAdapter from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter +# Default config of InternalLoggerAdapter _default_config = URL( - protocol=LoggerConstants.LOGGER_DRIVER_VALUE, - host=LoggerConstants.LOGGER_LEVEL_VALUE.value, - port=None, + protocol=logger_constants.DEFAULT_DRIVER_VALUE, + host=logger_constants.DEFAULT_LEVEL_VALUE.value, parameters={ - LoggerConstants.LOGGER_DRIVER_KEY: LoggerConstants.LOGGER_DRIVER_VALUE, - LoggerConstants.LOGGER_LEVEL_KEY: LoggerConstants.LOGGER_LEVEL_VALUE.value, - LoggerConstants.LOGGER_CONSOLE_ENABLED_KEY: str( - LoggerConstants.LOGGER_CONSOLE_ENABLED_VALUE + logger_constants.DRIVER_KEY: logger_constants.DEFAULT_DRIVER_VALUE, + logger_constants.LEVEL_KEY: logger_constants.DEFAULT_LEVEL_VALUE.value, + logger_constants.CONSOLE_ENABLED_KEY: str( + logger_constants.DEFAULT_CONSOLE_ENABLED_VALUE ), - LoggerConstants.LOGGER_FILE_ENABLED_KEY: str( - LoggerConstants.LOGGER_FILE_ENABLED_VALUE + logger_constants.FILE_ENABLED_KEY: str( + logger_constants.DEFAULT_FILE_ENABLED_VALUE ), }, ) @@ -41,13 +42,14 @@ class LoggerFactory: """ Factory class to create loggers. + Attributes: + _logger_adapter(LoggerAdapter): logger adapter. Default: InternalLoggerAdapter(_default_config) + _loggers(Dict[str, LoggerAdapter]): A dictionary to store all the loggers. + _loggers_lock(threading.Lock): The lock is used to lock all loggers when the logger adapter is changed. """ - # logger adapter _logger_adapter = InternalLoggerAdapter(_default_config) - # A dictionary to store all the loggers. _loggers: Dict[str, Logger] = {} - # A lock to protect the loggers. _loggers_lock = threading.Lock() @classmethod @@ -56,11 +58,14 @@ def set_logger_adapter(cls, logger_adapter) -> None: Set logger config """ cls._logger_adapter = logger_adapter - with cls._loggers_lock: + cls._loggers_lock.acquire() + try: # update all loggers cls._loggers = { name: cls._logger_adapter.get_logger(name) for name in cls._loggers } + finally: + cls._loggers_lock.release() @classmethod def get_logger_adapter(cls) -> LoggerAdapter: @@ -85,28 +90,32 @@ def get_logger(cls, name: str) -> Logger: """ logger = cls._loggers.get(name) if logger is None: - with cls._loggers_lock: + cls._loggers_lock.acquire() + try: if name not in cls._loggers: cls._loggers[name] = cls._logger_adapter.get_logger(name) logger = cls._loggers[name] + finally: + cls._loggers_lock.release() + return logger @classmethod - def get_level(cls) -> LoggerLevel: + def get_level(cls) -> Level: """ Get the current logging level. Returns: - LoggerLevel: The current logging level. + Level: The current logging level. """ return cls._logger_adapter.level @classmethod - def set_level(cls, level: LoggerLevel) -> None: + def set_level(cls, level: Level) -> None: """ Set the logging level. Args: - level (LoggerLevel): The logging level to set. + level (Level): The logging level to set. 
""" cls._logger_adapter.level = level diff --git a/tests/common/tets_url.py b/tests/common/tets_url.py index 736f870..7252500 100644 --- a/tests/common/tets_url.py +++ b/tests/common/tets_url.py @@ -26,7 +26,7 @@ def test_str_to_url(self): ) self.assertEqual("http", url_0.protocol) self.assertEqual("www.facebook.com", url_0.host) - self.assertEqual(None, url_0.port) + self.assertEqual(0, url_0.port) self.assertEqual("friends", url_0.path) self.assertEqual("value1", url_0.get_parameter("param1")) self.assertEqual("value2", url_0.get_parameter("param2")) @@ -50,7 +50,7 @@ def test_str_to_url(self): ) self.assertEqual("http", url_3.protocol) self.assertEqual("www.facebook.com", url_3.host) - self.assertEqual(None, url_3.port) + self.assertEqual(0, url_3.port) self.assertEqual("friends", url_3.path) self.assertEqual("value1", url_3.get_parameter("param1")) self.assertEqual("value2", url_3.get_parameter("param2")) diff --git a/tests/logger/test_internal_logger.py b/tests/logger/test_internal_logger.py index 0150997..2e53998 100644 --- a/tests/logger/test_internal_logger.py +++ b/tests/logger/test_internal_logger.py @@ -15,7 +15,7 @@ # limitations under the License. import unittest -from dubbo.common.constants import LoggerLevel +from dubbo.common.constants import Level from dubbo.config import LoggerConfig from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter @@ -25,7 +25,7 @@ class TestInternalLogger(unittest.TestCase): def test_log(self): logger_adapter = InternalLoggerAdapter(config=LoggerConfig().get_url()) logger = logger_adapter.get_logger("test") - logger.log(LoggerLevel.INFO, "test log") + logger.log(Level.INFO, "test log") logger.debug("test debug") logger.info("test info") logger.warning("test warning") @@ -38,11 +38,11 @@ def test_log(self): logger.exception("test exception") # test different default logger level - logger_adapter.level = LoggerLevel.INFO + logger_adapter.level = Level.INFO logger.debug("debug can't be logged") - logger_adapter.level = LoggerLevel.WARNING + logger_adapter.level = Level.WARNING logger.info("info can't be logged") - logger_adapter.level = LoggerLevel.ERROR + logger_adapter.level = Level.ERROR logger.warning("warning can't be logged") diff --git a/tests/logger/test_logger_factory.py b/tests/logger/test_logger_factory.py index 03446b1..c33204a 100644 --- a/tests/logger/test_logger_factory.py +++ b/tests/logger/test_logger_factory.py @@ -15,7 +15,8 @@ # limitations under the License. 
import unittest -from dubbo.common.constants import LoggerConstants, LoggerLevel +from dubbo.common.constants import logger as logger_constants +from dubbo.common.constants.logger import Level from dubbo.config import LoggerConfig from dubbo.logger import loggerFactory from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter @@ -30,18 +31,18 @@ def test_without_config(self): def test_with_config(self): # Test the case where config is used - config = LoggerConfig() + config = LoggerConfig("logging") config.init() logger = loggerFactory.get_logger("test_factory") logger.info("info log -> with_config ") url = config.get_url() - url.add_parameter(LoggerConstants.LOGGER_FILE_ENABLED_KEY, True) + url.add_parameter(logger_constants.FILE_ENABLED_KEY, True) loggerFactory.set_logger_adapter(InternalLoggerAdapter(url)) - loggerFactory.set_level(LoggerLevel.DEBUG) + loggerFactory.set_level(Level.DEBUG) logger.debug("debug log -> with_config -> open file") - url.add_parameter(LoggerConstants.LOGGER_CONSOLE_ENABLED_KEY, False) + url.add_parameter(logger_constants.CONSOLE_ENABLED_KEY, False) loggerFactory.set_logger_adapter(InternalLoggerAdapter(url)) - loggerFactory.set_level(LoggerLevel.DEBUG) + loggerFactory.set_level(Level.DEBUG) logger.debug("debug log -> with_config -> lose console") From 9206c5a8b5bc28ac6327faf93860bf9b15af030e Mon Sep 17 00:00:00 2001 From: zaki Date: Mon, 17 Jun 2024 12:00:59 +0800 Subject: [PATCH 20/32] fix: fix ci --- tests/common/extension/test_logger_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/common/extension/test_logger_extension.py b/tests/common/extension/test_logger_extension.py index b5eda81..350be07 100644 --- a/tests/common/extension/test_logger_extension.py +++ b/tests/common/extension/test_logger_extension.py @@ -25,7 +25,7 @@ def test_logger_extension(self): # Test the get_logger_adapter method. logger_adapter = extension.get_logger_adapter( - "internal", LoggerConfig().get_url() + "logging", LoggerConfig("logging").get_url() ) # Test logger_adapter methods. From 345eafea92bd88a94dd9aec2eb3490b1034cc4cb Mon Sep 17 00:00:00 2001 From: zaki Date: Mon, 17 Jun 2024 12:03:07 +0800 Subject: [PATCH 21/32] fix: fix ci --- tests/logger/test_internal_logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/logger/test_internal_logger.py b/tests/logger/test_internal_logger.py index 2e53998..91fbbb5 100644 --- a/tests/logger/test_internal_logger.py +++ b/tests/logger/test_internal_logger.py @@ -15,7 +15,7 @@ # limitations under the License. 
import unittest -from dubbo.common.constants import Level +from dubbo.common.constants.logger import Level from dubbo.config import LoggerConfig from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter @@ -23,7 +23,7 @@ class TestInternalLogger(unittest.TestCase): def test_log(self): - logger_adapter = InternalLoggerAdapter(config=LoggerConfig().get_url()) + logger_adapter = InternalLoggerAdapter(config=LoggerConfig("logging").get_url()) logger = logger_adapter.get_logger("test") logger.log(Level.INFO, "test log") logger.debug("test debug") From 05ec4db29af960ef295d7d38aab8dfcf5c1c29fd Mon Sep 17 00:00:00 2001 From: zaki Date: Sat, 29 Jun 2024 13:40:36 +0800 Subject: [PATCH 22/32] feat: update something about client --- .flake8 | 2 +- .../python-lint-and-license-check.yml | 6 - dubbo/__init__.py | 1 - dubbo/_dubbo.py | 160 ++++++++++++++++- .../{logger/internal => callable}/__init__.py | 0 dubbo/callable/rpc_callable.py | 78 +++++++++ dubbo/callable/rpc_callable_factory.py | 37 ++++ dubbo/client/client.py | 112 +++++++++++- dubbo/common/constants/common_constants.py | 37 ++++ .../{logger.py => logger_constants.py} | 3 - dubbo/common/constants/type_constants.py | 19 ++ dubbo/common/extension/logger_extension.py | 68 -------- dubbo/common/url.py | 57 ++++-- .../extension => compressor}/__init__.py | 1 - .../{imports.py => compressor/compressor.py} | 6 +- dubbo/config/__init__.py | 6 +- dubbo/config/application_config.py | 45 +++++ dubbo/config/consumer_config.py | 30 ++++ dubbo/config/logger_config.py | 78 ++++----- dubbo/config/method_config.py | 67 +++++++ dubbo/config/protocol_config.py | 30 ++++ dubbo/config/reference_config.py | 74 ++++++++ dubbo/extension/__init__.py | 20 +++ dubbo/extension/extension_loader.py | 89 ++++++++++ dubbo/extension/registry.py | 64 +++++++ dubbo/logger/__init__.py | 5 - dubbo/logger/logger.py | 2 +- dubbo/logger/logger_factory.py | 28 +-- dubbo/logger/logging/__init__.py | 17 ++ dubbo/logger/logging/formatter.py | 86 +++++++++ dubbo/logger/{internal => logging}/logger.py | 8 +- .../{internal => logging}/logger_adapter.py | 53 +++--- dubbo/loop/__init__.py | 58 ++++++ dubbo/loop/loop_manger.py | 111 ++++++++++++ dubbo/protocol/__init__.py | 15 ++ dubbo/protocol/invocation.py | 78 +++++++++ dubbo/protocol/invoker.py | 35 ++++ dubbo/protocol/protocol.py | 30 ++++ dubbo/protocol/result.py | 19 ++ dubbo/protocol/triple/__init__.py | 15 ++ dubbo/protocol/triple/tri_decoder.py | 152 ++++++++++++++++ dubbo/protocol/triple/tri_invoker.py | 37 ++++ dubbo/protocol/triple/tri_stream.py | 86 +++++++++ dubbo/protocol/triple/triple_protocol.py | 28 +++ dubbo/remoting/__init__.py | 15 ++ dubbo/remoting/aio/__init__.py | 15 ++ dubbo/remoting/aio/aio_transporter.py | 91 ++++++++++ dubbo/remoting/aio/http2_protocol.py | 165 ++++++++++++++++++ dubbo/remoting/transporter.py | 40 +++++ dubbo/serialization/__init__.py | 15 ++ dubbo/serialization/serialization.py | 83 +++++++++ requirements.txt | 1 + tests/logger/test_logger_factory.py | 15 +- ...ernal_logger.py => test_logging_logger.py} | 8 +- tests/loop/__init__.py | 15 ++ tests/loop/test_loop_manger.py | 37 ++++ tests/test_client.py | 81 +++++++++ tests/test_server.py | 43 +++++ 58 files changed, 2370 insertions(+), 207 deletions(-) rename dubbo/{logger/internal => callable}/__init__.py (100%) create mode 100644 dubbo/callable/rpc_callable.py create mode 100644 dubbo/callable/rpc_callable_factory.py create mode 100644 dubbo/common/constants/common_constants.py rename dubbo/common/constants/{logger.py => 
logger_constants.py} (95%) create mode 100644 dubbo/common/constants/type_constants.py delete mode 100644 dubbo/common/extension/logger_extension.py rename dubbo/{common/extension => compressor}/__init__.py (91%) rename dubbo/{imports.py => compressor/compressor.py} (85%) create mode 100644 dubbo/config/application_config.py create mode 100644 dubbo/config/consumer_config.py create mode 100644 dubbo/config/method_config.py create mode 100644 dubbo/config/protocol_config.py create mode 100644 dubbo/config/reference_config.py create mode 100644 dubbo/extension/__init__.py create mode 100644 dubbo/extension/extension_loader.py create mode 100644 dubbo/extension/registry.py create mode 100644 dubbo/logger/logging/__init__.py create mode 100644 dubbo/logger/logging/formatter.py rename dubbo/logger/{internal => logging}/logger.py (93%) rename dubbo/logger/{internal => logging}/logger_adapter.py (76%) create mode 100644 dubbo/loop/__init__.py create mode 100644 dubbo/loop/loop_manger.py create mode 100644 dubbo/protocol/__init__.py create mode 100644 dubbo/protocol/invocation.py create mode 100644 dubbo/protocol/invoker.py create mode 100644 dubbo/protocol/protocol.py create mode 100644 dubbo/protocol/result.py create mode 100644 dubbo/protocol/triple/__init__.py create mode 100644 dubbo/protocol/triple/tri_decoder.py create mode 100644 dubbo/protocol/triple/tri_invoker.py create mode 100644 dubbo/protocol/triple/tri_stream.py create mode 100644 dubbo/protocol/triple/triple_protocol.py create mode 100644 dubbo/remoting/__init__.py create mode 100644 dubbo/remoting/aio/__init__.py create mode 100644 dubbo/remoting/aio/aio_transporter.py create mode 100644 dubbo/remoting/aio/http2_protocol.py create mode 100644 dubbo/remoting/transporter.py create mode 100644 dubbo/serialization/__init__.py create mode 100644 dubbo/serialization/serialization.py rename tests/logger/{test_internal_logger.py => test_logging_logger.py} (87%) create mode 100644 tests/loop/__init__.py create mode 100644 tests/loop/test_loop_manger.py create mode 100644 tests/test_client.py create mode 100644 tests/test_server.py diff --git a/.flake8 b/.flake8 index 233cd14..44d4fa3 100644 --- a/.flake8 +++ b/.flake8 @@ -24,6 +24,6 @@ exclude = per-file-ignores = __init__.py:F401 # module level import not at top of file - dubbo/imports.py:F401 + dubbo/_imports.py:F401 # module level import not at top of file dubbo/common/extension/logger_extension.py:E402 diff --git a/.github/workflows/python-lint-and-license-check.yml b/.github/workflows/python-lint-and-license-check.yml index 1cbb9cd..b552112 100644 --- a/.github/workflows/python-lint-and-license-check.yml +++ b/.github/workflows/python-lint-and-license-check.yml @@ -19,12 +19,6 @@ jobs: pip install flake8 flake8 . - - name: Type check with MyPy - run: | - # fail if there are any MyPy errors - pip install mypy - mypy ./dubbo - check-license: runs-on: ubuntu-latest steps: diff --git a/dubbo/__init__.py b/dubbo/__init__.py index b31a846..a5a99ea 100644 --- a/dubbo/__init__.py +++ b/dubbo/__init__.py @@ -13,6 +13,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import dubbo.imports from ._dubbo import Dubbo diff --git a/dubbo/_dubbo.py b/dubbo/_dubbo.py index 4f7a73b..05a096f 100644 --- a/dubbo/_dubbo.py +++ b/dubbo/_dubbo.py @@ -13,9 +13,165 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import threading +from typing import Dict, List + +from dubbo.config import (ApplicationConfig, ConsumerConfig, LoggerConfig, + ProtocolConfig) +from dubbo.logger.logger_factory import loggerFactory + +logger = loggerFactory.get_logger(__name__) class Dubbo: - """The entry point of dubbo-python framework.(singleton)""" - pass + # class variable + _instance = None + _ins_lock = threading.Lock() + + # instance variable + # common + _application: ApplicationConfig + _protocols: Dict[str, ProtocolConfig] + _logger: LoggerConfig + # consumer + _consumer: ConsumerConfig + # provider + # .... + + __slots__ = ["_application", "_protocols", "_logger", "_consumer"] + + def __new__(cls, *args, **kwargs): + # dubbo object is singleton + if cls._instance is None: + with cls._ins_lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + # common + self._application = ApplicationConfig.default_config() + self._protocols = {} + self._logger = LoggerConfig.default_config() + # consumer + self._consumer = ConsumerConfig.default_config() + # provider + # TODO add provider config + + # @overload + # def new_client( + # self, reference: str, consumer: Optional[ConsumerConfig] = None + # ) -> Client: ... + # + # @overload + # def new_client( + # self, + # reference: ReferenceConfig, + # consumer: Optional[ConsumerConfig] = None, + # ) -> Client: ... + # + # def new_client( + # self, + # reference: Union[str, ReferenceConfig], + # consumer: Optional[ConsumerConfig] = None, + # ) -> Client: + # """ + # Create a new client + # Args: + # reference: reference value + # consumer: consumer config + # Returns: + # Client: A new instance of Client + # """ + # if isinstance(reference, str): + # reference = ReferenceConfig() + # elif isinstance(reference, ReferenceConfig): + # reference = reference + # else: + # raise TypeError( + # "reference must be a string or an instance of ReferenceConfig" + # ) + # consumer_config = consumer or self._consumer.clone() + # return Client(reference, consumer_config) + + def new_server(self): + """ + Create a new server + """ + pass + + def _init(self): + pass + + def start(self): + pass + + def destroy(self): + pass + + def with_application(self, application_config: ApplicationConfig) -> "Dubbo": + """ + Set application config + Args: + application_config: new application config + Returns: + self: Dubbo instance + """ + if application_config is None or not isinstance( + application_config, ApplicationConfig + ): + raise ValueError("application must be an instance of ApplicationConfig") + self._application = application_config + return self + + def with_protocol(self, protocol_config: ProtocolConfig) -> "Dubbo": + """ + Set protocol config + Args: + protocol_config: new protocol config + Returns: + self: Dubbo instance + """ + if protocol_config is None or not isinstance(protocol_config, ProtocolConfig): + raise ValueError("protocol must be an instance of ProtocolConfig") + self._protocols[protocol_config.name] = protocol_config + return self + + def with_protocols(self, protocol_configs: List[ProtocolConfig]) -> "Dubbo": + """ + Set protocol config + Args: + protocol_configs: new protocol configs + Returns: + self: Dubbo instance + """ + for protocol_config in protocol_configs: + self.with_protocol(protocol_config) + return self + + def with_logger(self, logger_config: LoggerConfig) -> "Dubbo": + """ + Set logger config + 
Args: + logger_config: new logger config + Returns: + self: Dubbo instance + """ + if logger_config is None or not isinstance(logger_config, LoggerConfig): + raise ValueError("logger must be an instance of LoggerConfig") + self._logger = logger_config + return self + + def with_consumer(self, consumer_config: ConsumerConfig) -> "Dubbo": + """ + Set consumer config + Args: + consumer_config: new consumer config + Returns: + self: Dubbo instance + """ + if consumer_config is None or not isinstance(consumer_config, ConsumerConfig): + raise ValueError("consumer must be an instance of ConsumerConfig") + self._consumer = consumer_config + return self diff --git a/dubbo/logger/internal/__init__.py b/dubbo/callable/__init__.py similarity index 100% rename from dubbo/logger/internal/__init__.py rename to dubbo/callable/__init__.py diff --git a/dubbo/callable/rpc_callable.py b/dubbo/callable/rpc_callable.py new file mode 100644 index 0000000..5f6405c --- /dev/null +++ b/dubbo/callable/rpc_callable.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any + +from dubbo.common.constants import common_constants +from dubbo.common.url import URL +from dubbo.protocol.invocation import RpcInvocation +from dubbo.protocol.invoker import Invoker + + +class RpcCallable: + + def __init__(self, invoker: Invoker, url: URL): + self._invoker = invoker + self._url = url + self._service_name = self._url.path or "" + method_url = self._url.get_attribute(common_constants.METHOD_KEY) + self._method_name = method_url.get_parameter(common_constants.METHOD_KEY) or "" + self._call_type = method_url.get_parameter(common_constants.TYPE_CALL) + self._req_serializer = ( + method_url.get_attribute(common_constants.SERIALIZATION) or None + ) + self._res_serializer = ( + method_url.get_attribute(common_constants.SERIALIZATION) or None + ) + + def _do_call(self, argument: Any): + """ + Real call method. + """ + if ( + self._call_type == common_constants.CALL_CLIENT_STREAM + and not inspect.isgeneratorfunction(argument) + ): + raise ValueError( + "Invalid argument: The provided argument must be a generator function " + ) + elif ( + self._call_type == common_constants.CALL_UNARY + and inspect.isgeneratorfunction(argument) + ): + raise ValueError( + "Invalid argument: The provided argument must be a normal function" + ) + + # Create a new RpcInvocation object. + invocation = RpcInvocation( + self._service_name, + self._method_name, + argument, + self._req_serializer, + self._res_serializer, + ) + # Do invoke. 
+ return self._invoker.invoke(invocation) + + def __call__(self, argument: Any): + return self._do_call(argument) + + +class AsyncRpcCallable: + + async def __call__(self, *args, **kwargs): + pass diff --git a/dubbo/callable/rpc_callable_factory.py b/dubbo/callable/rpc_callable_factory.py new file mode 100644 index 0000000..55edbba --- /dev/null +++ b/dubbo/callable/rpc_callable_factory.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.callable.rpc_callable import RpcCallable +from dubbo.common.url import URL +from dubbo.protocol.invoker import Invoker + + +class RpcCallableFactory: + + def get_proxy(self, url: URL, invoker: Invoker) -> RpcCallable: + """ + Get the callable object. + Args: + url (URL): The URL. + invoker (Invoker): The invoker object. + """ + raise NotImplementedError("get_proxy() is not implemented") + + +class DefaultRpcCallableFactory(RpcCallableFactory): + + def get_proxy(self, url: URL, invoker: Invoker) -> RpcCallable: + pass diff --git a/dubbo/client/client.py b/dubbo/client/client.py index e4eaefd..f66a523 100644 --- a/dubbo/client/client.py +++ b/dubbo/client/client.py @@ -13,11 +13,119 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
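Editorial aside (not part of the patch series): a minimal sketch of how the call-type API added to dubbo/client/client.py in the hunk below might be used, assuming the ReferenceConfig constructor introduced later in this commit; the service name, address, and serializer lambdas are hypothetical. Because DefaultRpcCallableFactory.get_proxy is still a stub at this point in the series, this illustrates the intended API rather than a runnable end-to-end call.

    from dubbo.client.client import Client
    from dubbo.config import ReferenceConfig

    # Hypothetical reference to a service exposed over the triple protocol.
    reference = ReferenceConfig(
        interface_name="org.example.GreeterService",
        check=False,
        url="tri://127.0.0.1:50051/org.example.GreeterService",
        protocol="tri",
    )
    client = Client(reference)

    # req_serializer: Any -> bytes, resp_deserializer: bytes -> Any (see type_constants).
    greet = client.unary(
        "greet",
        req_serializer=lambda text: text.encode("utf-8"),
        resp_deserializer=lambda data: data.decode("utf-8"),
    )
    reply = greet("hello")  # unary call: plain argument in, single result out
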
-from dubbo.logger import loggerFactory +from typing import Optional, Union + +from dubbo.callable.rpc_callable import AsyncRpcCallable, RpcCallable +from dubbo.callable.rpc_callable_factory import DefaultRpcCallableFactory +from dubbo.common.constants import common_constants +from dubbo.common.constants.type_constants import (DeserializingFunction, + SerializingFunction) +from dubbo.common.url import URL +from dubbo.config import ConsumerConfig, ReferenceConfig +from dubbo.logger.logger_factory import loggerFactory logger = loggerFactory.get_logger(__name__) class Client: - pass + _consumer: ConsumerConfig + _reference: ReferenceConfig + + __slots__ = ["_consumer", "_reference"] + + def __init__( + self, reference: ReferenceConfig, consumer: Optional[ConsumerConfig] = None + ): + self._reference = reference + self._consumer = consumer or ConsumerConfig.default_config() + + def unary( + self, + method_name: str, + req_serializer: Optional[SerializingFunction] = None, + resp_deserializer: Optional[DeserializingFunction] = None, + ) -> Union[RpcCallable, AsyncRpcCallable]: + return self._callable( + common_constants.CALL_UNARY, method_name, req_serializer, resp_deserializer + ) + + def client_stream( + self, + method_name: str, + req_serializer: Optional[SerializingFunction] = None, + resp_deserializer: Optional[DeserializingFunction] = None, + ) -> Union[RpcCallable, AsyncRpcCallable]: + return self._callable( + common_constants.CALL_CLIENT_STREAM, + method_name, + req_serializer, + resp_deserializer, + ) + + def server_stream( + self, + method_name: str, + req_serializer: Optional[SerializingFunction] = None, + resp_deserializer: Optional[DeserializingFunction] = None, + ) -> Union[RpcCallable, AsyncRpcCallable]: + return self._callable( + common_constants.CALL_SERVER_STREAM, + method_name, + req_serializer, + resp_deserializer, + ) + + def bidi_stream( + self, + method_name: str, + req_serializer: Optional[SerializingFunction] = None, + resp_deserializer: Optional[DeserializingFunction] = None, + ) -> Union[RpcCallable, AsyncRpcCallable]: + return self._callable( + common_constants.CALL_BIDI_STREAM, + method_name, + req_serializer, + resp_deserializer, + ) + + def _callable( + self, + call_type: str, + method_name: str, + req_serializer: Optional[SerializingFunction] = None, + resp_deserializer: Optional[DeserializingFunction] = None, + ) -> Union[RpcCallable, AsyncRpcCallable]: + """ + Generate a callable for the given method + Args: + call_type: call type + method_name: method name + req_serializer: request serializer, args: Any, return: bytes + resp_deserializer: response deserializer, args: bytes, return: Any + Returns: + RpcCallable: The callable object + """ + # get invoker + invoker = self._reference.get_invoker() + url = invoker.get_url() + + method_url = URL( + method_name, + common_constants.LOCALHOST_KEY, + parameters={ + common_constants.METHOD_KEY: method_name, + common_constants.TYPE_CALL: call_type, + }, + ) + # add attributes + method_url.add_attribute(common_constants.SERIALIZATION, req_serializer) + method_url.add_attribute(common_constants.DESERIALIZATION, resp_deserializer) + + # put the method url into the invoker url + url.add_attribute(method_name, method_url) + + # create callable + rpc_callable = DefaultRpcCallableFactory().get_proxy(invoker, url) + + return rpc_callable diff --git a/dubbo/common/constants/common_constants.py b/dubbo/common/constants/common_constants.py new file mode 100644 index 0000000..c985045 --- /dev/null +++ 
b/dubbo/common/constants/common_constants.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +TRIPLE = "tri" + +LOCALHOST_KEY = "localhost" +LOCALHOST_VALUE = "127.0.0.1" + +TYPE_CALL = "call" +CALL_UNARY = "unary" +CALL_CLIENT_STREAM = "client-stream" +CALL_SERVER_STREAM = "server-stream" +CALL_BIDI_STREAM = "bidi-stream" + +SERIALIZATION = "serialization" +DESERIALIZATION = "deserialization" + +SERVER_KEY = "server" +METHOD_KEY = "method" + + +TRUE_VALUE = "true" +FALSE_VALUE = "false" diff --git a/dubbo/common/constants/logger.py b/dubbo/common/constants/logger_constants.py similarity index 95% rename from dubbo/common/constants/logger.py rename to dubbo/common/constants/logger_constants.py index b68cab8..40ae17e 100644 --- a/dubbo/common/constants/logger.py +++ b/dubbo/common/constants/logger_constants.py @@ -59,15 +59,12 @@ class FileRotateType(enum.Enum): # global config LEVEL_KEY = "logger.level" DRIVER_KEY = "logger.driver" -FORMAT_KEY = "logger.format" # console config CONSOLE_ENABLED_KEY = "logger.console.enable" -CONSOLE_FORMAT_KEY = "logger.console.format" # file logger FILE_ENABLED_KEY = "logger.file.enable" -FILE_FORMAT_KEY = "logger.file.format" FILE_DIR_KEY = "logger.file.dir" FILE_NAME_KEY = "logger.file.name" FILE_ROTATE_KEY = "logger.file.rotate" diff --git a/dubbo/common/constants/type_constants.py b/dubbo/common/constants/type_constants.py new file mode 100644 index 0000000..bb332be --- /dev/null +++ b/dubbo/common/constants/type_constants.py @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable + +SerializingFunction = Callable[[Any], bytes] +DeserializingFunction = Callable[[bytes], Any] diff --git a/dubbo/common/extension/logger_extension.py b/dubbo/common/extension/logger_extension.py deleted file mode 100644 index 71c3470..0000000 --- a/dubbo/common/extension/logger_extension.py +++ /dev/null @@ -1,68 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This module provides an extension point for logger adapters. -""" -from typing import Dict - -from dubbo.common.url import URL -from dubbo.logger import LoggerAdapter - -# A dictionary to store all the logger adapters. key: name, value: logger adapter class -_logger_adapter_dict: Dict[str, type[LoggerAdapter]] = {} - - -def register_logger_adapter(name: str): - """ - A decorator to register a logger class to the logger extension point. - - This function returns a decorator that registers the decorated class - as a logger adapter under the specified name. - - Args: - name (str): The name to register the logger adapter under. - - Returns: - Callable[[Type[LoggerAdapter]], Type[LoggerAdapter]]: - A decorator function that registers the logger class. - """ - - def wrapper(cls): - _logger_adapter_dict[name] = cls - return cls - - return wrapper - - -def get_logger_adapter(name: str, config: URL) -> LoggerAdapter: - """ - Get a logger adapter instance by name. - - This function retrieves a logger adapter class by its registered name and - instantiates it with the provided arguments. - - Args: - name (str): The name of the logger adapter to retrieve. - config (URL): The config of the logger adapter to retrieve. - - Returns: - LoggerAdapter: An instance of the requested logger adapter. - Raises: - KeyError: If no logger adapter is registered under the provided name. - """ - logger_adapter = _logger_adapter_dict[name] - return logger_adapter(config) diff --git a/dubbo/common/url.py b/dubbo/common/url.py index 64dcf4c..b4e65a0 100644 --- a/dubbo/common/url.py +++ b/dubbo/common/url.py @@ -20,14 +20,15 @@ class URL: """ URL - Uniform Resource Locator. - Attributes: - _protocol (str): The protocol of the URL. - _host (str): The host of the URL. - _port (int): The port number of the URL. - _username (str): The username for URL authentication. - _password (str): The password for URL authentication. - _path (str): The path of the URL. - _parameters (Dict[str, str]): The query parameters of the URL. + Args: + protocol (str): The protocol of the URL. + host (str): The host of the URL. + port (int): The port number of the URL. + username (str): The username for URL authentication. + password (str): The password for URL authentication. + path (str): The path of the URL. + parameters (Dict[str, str]): The query parameters of the URL. + attributes (Dict[str, Any]): The attributes of the URL. 
(non-transferable) url example: - http://www.facebook.com/friends?param1=value1¶m2=value2 @@ -36,14 +37,6 @@ class URL: - registry://192.168.1.7:9090/org.apache.dubbo.service1?param1=value1¶m2=value2 """ - _protocol: str - _username: str - _password: str - _host: str - _port: int - _path: str - _parameters: Dict[str, str] - def __init__( self, protocol: str, @@ -53,6 +46,7 @@ def __init__( password: str = "", path: str = "", parameters: Optional[Dict[str, str]] = None, + attributes: Optional[Dict[str, Any]] = None, ): self._protocol = protocol self._host = host @@ -63,6 +57,7 @@ def __init__( self._password = password self._path = path self._parameters = parameters or {} + self._attributes = attributes or {} @property def protocol(self) -> str: @@ -238,6 +233,34 @@ def add_parameter(self, key: str, value: Any) -> None: """ self._parameters[key] = str(value) if value is not None else "" + @property + def attributes(self): + """ + Gets the attributes of the URL. + Returns: + Dict[str, Any]: The attributes of the URL. + """ + return self._attributes + + def add_attribute(self, key: str, value: Any) -> None: + """ + ADDs an attribute to the URL. + Args: + key (str): The attribute name. + value (Any): The attribute value. + """ + self._attributes[key] = value + + def get_attribute(self, key: str) -> Optional[Any]: + """ + Gets an attribute from the URL. + Args: + key (str): The attribute name. + Returns: + Any: The attribute value. If the attribute does not exist, returns None. + """ + return self._attributes.get(key, None) + def build_string(self, encode: bool = False) -> str: """ Generates the URL string based on the current components. @@ -292,7 +315,7 @@ def value_of(cls, url: str, encoded: bool = False) -> "URL": URL: The created URL object. """ if not url: - raise ValueError() + raise ValueError("URL string cannot be empty or None.") # If the URL is encoded, decode it if encoded: diff --git a/dubbo/common/extension/__init__.py b/dubbo/compressor/__init__.py similarity index 91% rename from dubbo/common/extension/__init__.py rename to dubbo/compressor/__init__.py index c3ee8fe..bcba37a 100644 --- a/dubbo/common/extension/__init__.py +++ b/dubbo/compressor/__init__.py @@ -13,4 +13,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .logger_extension import get_logger_adapter, register_logger_adapter diff --git a/dubbo/imports.py b/dubbo/compressor/compressor.py similarity index 85% rename from dubbo/imports.py rename to dubbo/compressor/compressor.py index 6d4c314..2edbc85 100644 --- a/dubbo/imports.py +++ b/dubbo/compressor/compressor.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utilizing the mechanism of module loading to complete the registration of plugins.""" -import dubbo.logger.internal.logger_adapter +class DeCompressor: + + def decompress(self, data: bytes) -> bytes: + pass diff --git a/dubbo/config/__init__.py b/dubbo/config/__init__.py index b6b51a2..63d9535 100644 --- a/dubbo/config/__init__.py +++ b/dubbo/config/__init__.py @@ -13,4 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
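Editorial aside (not part of the patch series): a small sketch of the parameter/attribute split added to URL in the hunk above; the values are illustrative only. Parameters stay string-valued and are carried in the URL string, while attributes hold arbitrary Python objects and are not serialized.

    from dubbo.common.url import URL

    url = URL(protocol="tri", host="127.0.0.1", port=50051, path="org.example.GreeterService")
    url.add_parameter("method", "greet")                              # stringified, part of the URL string
    url.add_attribute("serialization", lambda obj: obj.encode("utf-8"))  # arbitrary object, non-transferable

    url.get_parameter("method")        # "greet"
    url.get_attribute("serialization") # the callable stored above
    url.get_attribute("missing")       # None when the attribute is absent
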
-from .logger_config import ConsoleLoggerConfig, FileLoggerConfig, LoggerConfig +from .application_config import ApplicationConfig +from .consumer_config import ConsumerConfig +from .logger_config import FileLoggerConfig, LoggerConfig +from .protocol_config import ProtocolConfig +from .reference_config import ReferenceConfig diff --git a/dubbo/config/application_config.py b/dubbo/config/application_config.py new file mode 100644 index 0000000..8ee0806 --- /dev/null +++ b/dubbo/config/application_config.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class ApplicationConfig: + """ + Application configuration. + Attributes: + _name(str): Application name + _version(str): Application version + _owner(str): Application owner + _organization(str): Application organization (BU) + _environment(str): Application environment, e.g. dev, test or production + """ + + _name: str + _version: str + _owner: str + _organization: str + _environment: str + + def clone(self) -> "ApplicationConfig": + """ + Clone the current configuration. + Returns: + ApplicationConfig: A new instance of ApplicationConfig. + """ + return ApplicationConfig() + + @classmethod + def default_config(cls): + return cls() diff --git a/dubbo/config/consumer_config.py b/dubbo/config/consumer_config.py new file mode 100644 index 0000000..5037efe --- /dev/null +++ b/dubbo/config/consumer_config.py @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class ConsumerConfig: + + def clone(self) -> "ConsumerConfig": + """ + Clone the current configuration. + Returns: + ConsumerConfig: A new instance of ConsumerConfig. 
+ """ + return ConsumerConfig() + + @classmethod + def default_config(cls): + return cls() diff --git a/dubbo/config/logger_config.py b/dubbo/config/logger_config.py index 43035b8..d91d5ba 100644 --- a/dubbo/config/logger_config.py +++ b/dubbo/config/logger_config.py @@ -16,30 +16,12 @@ from dataclasses import dataclass from typing import Dict, Optional -from dubbo.common import extension -from dubbo.common.constants import logger as logger_constants -from dubbo.common.constants.logger import FileRotateType, Level +from dubbo.common.constants import logger_constants as logger_constants +from dubbo.common.constants.logger_constants import FileRotateType, Level from dubbo.common.url import URL -from dubbo.logger import loggerFactory - - -@dataclass -class ConsoleLoggerConfig: - """ - Console logger configuration. - Attributes: - console_format(Optional[str]): console format, if null, use global format. - """ - - console_format: Optional[str] = None - - def check(self): - pass - - def dict(self) -> Dict[str, str]: - return { - logger_constants.CONSOLE_FORMAT_KEY: self.console_format or "", - } +from dubbo.extension import extensionLoader +from dubbo.logger import LoggerAdapter +from dubbo.logger.logger_factory import loggerFactory @dataclass @@ -73,7 +55,6 @@ def check(self) -> None: def dict(self) -> Dict[str, str]: return { - logger_constants.FILE_FORMAT_KEY: self.file_formatter or "", logger_constants.FILE_DIR_KEY: self.file_dir, logger_constants.FILE_NAME_KEY: self.file_name, logger_constants.FILE_ROTATE_KEY: self.rotate.value, @@ -90,9 +71,7 @@ class LoggerConfig: Attributes: _driver(str): logger driver type. _level(Level): logger level. - _formatter(Optional[str]): logger formatter. _console_enabled(bool): logger console enabled. - _console_config(ConsoleLoggerConfig): logger console config. _file_enabled(bool): logger file enabled. _file_config(FileLoggerConfig): logger file config. 
""" @@ -100,33 +79,34 @@ class LoggerConfig: # global _driver: str _level: Level - _formatter: Optional[str] # console _console_enabled: bool - _console_config: ConsoleLoggerConfig # file _file_enabled: bool _file_config: FileLoggerConfig + __slots__ = [ + "_driver", + "_level", + "_console_enabled", + "_console_config", + "_file_enabled", + "_file_config", + ] + def __init__( self, driver, - level=logger_constants.DEFAULT_LEVEL_VALUE, - formatter: Optional[str] = None, - console_enabled: bool = logger_constants.DEFAULT_CONSOLE_ENABLED_VALUE, - console_config: ConsoleLoggerConfig = ConsoleLoggerConfig(), - file_enabled: bool = logger_constants.DEFAULT_FILE_ENABLED_VALUE, - file_config: FileLoggerConfig = FileLoggerConfig(), + level, + console_enabled: bool, + file_enabled: bool, + file_config: FileLoggerConfig, ): # set global config self._driver = driver self._level = level - self._formatter = formatter # set console config self._console_enabled = console_enabled - self._console_config = console_config - if console_enabled: - self._console_config.check() # set file comfig self._file_enabled = file_enabled self._file_config = file_config @@ -138,10 +118,8 @@ def get_url(self) -> URL: parameters = { logger_constants.DRIVER_KEY: self._driver, logger_constants.LEVEL_KEY: self._level.value, - logger_constants.FORMAT_KEY: self._formatter or "", logger_constants.CONSOLE_ENABLED_KEY: str(self._console_enabled), logger_constants.FILE_ENABLED_KEY: str(self._file_enabled), - **self._console_config.dict(), **self._file_config.dict(), } @@ -149,5 +127,21 @@ def get_url(self) -> URL: def init(self): # get logger_adapter and initialize loggerFactory - logger_adapter = extension.get_logger_adapter(self._driver, self.get_url()) - loggerFactory.logger_adapter = logger_adapter + logger_adapter_class = extensionLoader.get_extension( + LoggerAdapter, self._driver + ) + logger_adapter = logger_adapter_class(self.get_url()) + loggerFactory.set_logger_adapter(logger_adapter) + + @classmethod + def default_config(cls): + """ + Get default logger configuration. + """ + return LoggerConfig( + driver=logger_constants.DEFAULT_DRIVER_VALUE, + level=logger_constants.DEFAULT_LEVEL_VALUE, + console_enabled=logger_constants.DEFAULT_CONSOLE_ENABLED_VALUE, + file_enabled=logger_constants.DEFAULT_FILE_ENABLED_VALUE, + file_config=FileLoggerConfig(), + ) diff --git a/dubbo/config/method_config.py b/dubbo/config/method_config.py new file mode 100644 index 0000000..f6c2dcd --- /dev/null +++ b/dubbo/config/method_config.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + + +class MethodConfig: + """ + MethodConfig is a configuration class for a method. + Attributes: + _interface_name (str): The name of the interface. 
+ _name (str): The name of the method. + _request_serialize (Optional[Callable[..., Any]]): The request serialization function. + _response_deserialize (Optional[Callable[..., Any]]): The response deserialization function. + """ + + _interface_name: str + _name: str + _request_serialize: Optional[Callable[..., Any]] + _response_deserialize: Optional[Callable[..., Any]] + + __slots__ = [ + "_interface_name", + "_name", + "_request_serialize", + "_response_deserialize", + ] + + def __init__( + self, + interface_name: str, + name: str, + request_serialize: Optional[Callable[..., Any]] = None, + response_deserialize: Optional[Callable[..., Any]] = None, + ): + self._interface_name = interface_name + self._name = name + self._request_serialize = request_serialize + self._response_deserialize = response_deserialize + + @property + def interface_name(self) -> str: + return self._interface_name + + @property + def name(self) -> str: + return self._name + + @property + def request_serialize(self) -> Optional[Callable[..., Any]]: + return self._request_serialize + + @property + def response_deserialize(self) -> Optional[Callable[..., Any]]: + return self._response_deserialize diff --git a/dubbo/config/protocol_config.py b/dubbo/config/protocol_config.py new file mode 100644 index 0000000..d629e1f --- /dev/null +++ b/dubbo/config/protocol_config.py @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class ProtocolConfig: + + _name: str + + __slots__ = ["_name"] + + @property + def name(self) -> str: + return self._name + + @name.setter + def name(self, value: str): + self._name = value diff --git a/dubbo/config/reference_config.py b/dubbo/config/reference_config.py new file mode 100644 index 0000000..fd30d8a --- /dev/null +++ b/dubbo/config/reference_config.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
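Editorial aside (not part of the patch series): a sketch of how the reworked LoggerConfig above is expected to be bootstrapped, assuming FileLoggerConfig keeps usable defaults; the DEBUG level is only an example.

    from dubbo.common.constants.logger_constants import Level
    from dubbo.config import FileLoggerConfig, LoggerConfig

    config = LoggerConfig.default_config()  # logging driver, console on, file off
    # Or spell everything out with the new constructor:
    config = LoggerConfig(
        driver="logging",
        level=Level.DEBUG,
        console_enabled=True,
        file_enabled=False,
        file_config=FileLoggerConfig(),
    )
    config.init()  # resolves the LoggerAdapter extension by driver name and installs it into loggerFactory
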
+import threading +from typing import List, Optional + +from dubbo.callable.rpc_callable_factory import RpcCallableFactory +from dubbo.common.url import URL +from dubbo.config.method_config import MethodConfig +from dubbo.extension import extensionLoader +from dubbo.protocol.invoker import Invoker +from dubbo.protocol.protocol import Protocol + + +class ReferenceConfig: + + _interface_name: str + _check: bool + _url: str + _protocol: str + _methods: List[MethodConfig] + + _global_lock: threading.Lock + _initialized: bool + _destroyed: bool + _protocol_ins: Optional[Protocol] + _invoker: Optional[Invoker] + _proxy_factory: Optional[RpcCallableFactory] + + def __init__( + self, + interface_name: str, + check: bool, + url: str, + protocol: str, + methods: Optional[List[MethodConfig]] = None, + ): + self._initialized = False + self._global_lock = threading.Lock() + self._destroyed = False + self._interface_name = interface_name + self._url = url + self._protocol = protocol + self._methods = methods or [] + + def get_invoker(self): + if not self._invoker: + self._do_init() + return self._invoker + + def _do_init(self): + with self._global_lock: + if self._initialized: + return + + clazz = extensionLoader.get_extension(Protocol, self._protocol) + self._protocol_ins = clazz() + self._create_invoker() + self._initialized = True + + def _create_invoker(self): + self._invoker = self._protocol_ins.refer(URL.value_of(self._url)) diff --git a/dubbo/extension/__init__.py b/dubbo/extension/__init__.py new file mode 100644 index 0000000..8744a34 --- /dev/null +++ b/dubbo/extension/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.extension.extension_loader import \ + ExtensionLoader as _ExtensionLoader + +extensionLoader = _ExtensionLoader() diff --git a/dubbo/extension/extension_loader.py b/dubbo/extension/extension_loader.py new file mode 100644 index 0000000..3c96040 --- /dev/null +++ b/dubbo/extension/extension_loader.py @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
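Editorial aside (not part of the patch series): the ExtensionLoader added below resolves an implementation class from the registry by interface and name, mirroring how LoggerConfig.init() uses it; a minimal sketch of that lookup, with the default logger URL used only as an example argument.

    from dubbo.config import LoggerConfig
    from dubbo.extension import extensionLoader
    from dubbo.logger import LoggerAdapter

    # get_extension returns the implementation class, not an instance; callers instantiate it.
    # Per the registry below this resolves dubbo.logger.logging.logger_adapter.LoggingLoggerAdapter.
    adapter_cls = extensionLoader.get_extension(LoggerAdapter, "logging")
    adapter = adapter_cls(LoggerConfig.default_config().get_url())
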
+import importlib +import threading +from typing import Any + +from dubbo.extension import registry +from dubbo.logger.logger_factory import loggerFactory + +logger = loggerFactory.get_logger(__name__) + + +class ExtensionLoader: + + _instance = None + _ins_lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + with cls._ins_lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + self._registries = registry.get_all_extended_registry() + + def get_extension(self, superclass: Any, name: str) -> Any: + # Get the registry for the extension + extension_impls = self._registries.get(superclass) + err_msg = None + if not extension_impls: + err_msg = f"Extension {superclass} is not registered." + logger.error(err_msg) + raise ValueError(err_msg) + + # Get the full name of the class -> module.class + full_name = extension_impls.get(name) + if not full_name: + err_msg = f"Extension {superclass} with name {name} is not registered." + logger.error(err_msg) + raise ValueError(err_msg) + + module_name = class_name = None + try: + # Split the full name into module and class + module_name, class_name = full_name.rsplit(".", 1) + # Load the module + module = importlib.import_module(module_name) + # Get the class from the module + subclass = getattr(module, class_name) + if subclass: + # Check if the class is a subclass of the extension + if issubclass(subclass, superclass) and subclass is not superclass: + # Return the class + return subclass + else: + err_msg = f"Class {class_name} does not inherit from {superclass}." + else: + err_msg = f"Class {class_name} not found in module {module_name}" + + if err_msg: + # If there is an error message, raise an exception + raise Exception(err_msg) + except ImportError as e: + logger.exception(f"Module {module_name} could not be imported.") + raise e + except AttributeError as e: + logger.exception(f"Class {class_name} not found in {module_name}.") + raise e + except Exception as e: + if err_msg: + logger.error(err_msg) + else: + logger.exception(f"An error occurred while loading {full_name}.") + raise e diff --git a/dubbo/extension/registry.py b/dubbo/extension/registry.py new file mode 100644 index 0000000..c0d0b12 --- /dev/null +++ b/dubbo/extension/registry.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import sys +from dataclasses import dataclass +from typing import Any, Protocol + +from dubbo.logger import LoggerAdapter + + +@dataclass +class ExtendedRegistry: + """ + A dataclass to represent an extended registry. + Attributes: + interface: Any -> The interface of the registry. + impls: dict[str, Any] -> A dict of implementations of the interface. 
-> {name: impl} + """ + + interface: Any + impls: dict[str, Any] + + +"""Protocol registry.""" +protocolRegistry = ExtendedRegistry( + interface=Protocol, + impls={ + "tri": "dubbo.protocol.triple.triple_protocol.TripleProtocol", + }, +) + +"""LoggerAdapter registry.""" +loggerAdapterRegistry = ExtendedRegistry( + interface=LoggerAdapter, + impls={ + "logging": "dubbo.logger.logging.logger_adapter.LoggingLoggerAdapter", + }, +) + + +def get_all_extended_registry() -> dict[Any, dict[str, Any]]: + """ + Get all extended registries in the current module. + :return: A dict of all extended registries. -> {interface: {name: impl}} + """ + current_module = sys.modules[__name__] + registries: dict[Any, dict[str, Any]] = {} + for name, obj in inspect.getmembers(current_module): + if isinstance(obj, ExtendedRegistry): + registries[obj.interface] = obj.impls + return registries diff --git a/dubbo/logger/__init__.py b/dubbo/logger/__init__.py index 5df0681..c7bee10 100644 --- a/dubbo/logger/__init__.py +++ b/dubbo/logger/__init__.py @@ -15,8 +15,3 @@ # limitations under the License. from .logger import Logger, LoggerAdapter -from .logger_factory import LoggerFactory as _LoggerFactory - -loggerFactory = _LoggerFactory - -__all__ = ["Logger", "LoggerAdapter", "loggerFactory"] diff --git a/dubbo/logger/logger.py b/dubbo/logger/logger.py index a0c7460..11f3595 100644 --- a/dubbo/logger/logger.py +++ b/dubbo/logger/logger.py @@ -15,7 +15,7 @@ # limitations under the License. from typing import Any -from dubbo.common.constants.logger import Level +from dubbo.common.constants.logger_constants import Level from dubbo.common.url import URL diff --git a/dubbo/logger/logger_factory.py b/dubbo/logger/logger_factory.py index 4b594ab..83024d4 100644 --- a/dubbo/logger/logger_factory.py +++ b/dubbo/logger/logger_factory.py @@ -16,13 +16,13 @@ import threading from typing import Dict -from dubbo.common.constants import logger as logger_constants -from dubbo.common.constants.logger import Level +from dubbo.common.constants import logger_constants as logger_constants +from dubbo.common.constants.logger_constants import Level from dubbo.common.url import URL -from dubbo.logger import Logger, LoggerAdapter -from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter +from dubbo.logger.logger import Logger, LoggerAdapter +from dubbo.logger.logging.logger_adapter import LoggingLoggerAdapter -# Default config of InternalLoggerAdapter +# Default logger config with default values. _default_config = URL( protocol=logger_constants.DEFAULT_DRIVER_VALUE, host=logger_constants.DEFAULT_LEVEL_VALUE.value, @@ -39,16 +39,16 @@ ) -class LoggerFactory: +class _LoggerFactory: """ - Factory class to create loggers. + LoggerFactory Attributes: - _logger_adapter(LoggerAdapter): logger adapter. Default: InternalLoggerAdapter(_default_config) - _loggers(Dict[str, LoggerAdapter]): A dictionary to store all the loggers. - _loggers_lock(threading.Lock): The lock is used to lock all loggers when the logger adapter is changed. + _logger_adapter (LoggerAdapter): The logger adapter. + _loggers (Dict[str, Logger]): The logger cache. + _loggers_lock (threading.Lock): The logger lock to protect the logger cache. """ - _logger_adapter = InternalLoggerAdapter(_default_config) + _logger_adapter = LoggingLoggerAdapter(_default_config) _loggers: Dict[str, Logger] = {} _loggers_lock = threading.Lock() @@ -89,7 +89,7 @@ def get_logger(cls, name: str) -> Logger: Logger: An instance of the requested logger. 
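The factory caches one Logger per name behind a lock, so every module in this patch obtains its logger the same way. A minimal usage sketch (the logger name is arbitrary):

from dubbo.logger.logger_factory import loggerFactory

logger = loggerFactory.get_logger("dubbo.example")
logger.info("logger resolved from the factory cache")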
""" logger = cls._loggers.get(name) - if logger is None: + if not logger: cls._loggers_lock.acquire() try: if name not in cls._loggers: @@ -97,7 +97,6 @@ def get_logger(cls, name: str) -> Logger: logger = cls._loggers[name] finally: cls._loggers_lock.release() - return logger @classmethod @@ -119,3 +118,6 @@ def set_level(cls, level: Level) -> None: level (Level): The logging level to set. """ cls._logger_adapter.level = level + + +loggerFactory = _LoggerFactory diff --git a/dubbo/logger/logging/__init__.py b/dubbo/logger/logging/__init__.py new file mode 100644 index 0000000..d8765ff --- /dev/null +++ b/dubbo/logger/logging/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .logger_adapter import LoggerAdapter diff --git a/dubbo/logger/logging/formatter.py b/dubbo/logger/logging/formatter.py new file mode 100644 index 0000000..56a002a --- /dev/null +++ b/dubbo/logger/logging/formatter.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re +from enum import Enum + + +class Colors(Enum): + """ + Colors for log messages. + """ + + END = "\033[0m" + BOLD = "\033[1m" + BLUE = "\033[34m" + GREEN = "\033[32m" + PURPLE = "\033[35m" + CYAN = "\033[36m" + RED = "\033[31m" + YELLOW = "\033[33m" + GREY = "\033[38;5;240m" + + +LEVEL_MAP = { + "DEBUG": Colors.BLUE.value, + "INFO": Colors.GREEN.value, + "WARNING": Colors.YELLOW.value, + "ERROR": Colors.RED.value, + "CRITICAL": Colors.RED.value + Colors.BOLD.value, +} + +DATE_FORMAT: str = "%Y-%m-%d %H:%M:%S" + +LOG_FORMAT: str = ( + f"{Colors.GREEN.value}%(asctime)s{Colors.END.value}" + " | " + f"%(level_color)s%(levelname)s{Colors.END.value}" + " | " + f"{Colors.CYAN.value}%(module)s:%(funcName)s:%(lineno)d{Colors.END.value}" + " - " + f"{Colors.PURPLE.value}[Dubbo]{Colors.END.value} " + f"%(msg_color)s%(message)s{Colors.END.value}" +) + + +class ColorFormatter(logging.Formatter): + """ + A formatter with color. 
+ It will format the log message like this: + 2024-06-24 16:39:57 | DEBUG | test_logger_factory:test_with_config:44 - [Dubbo] debug log + """ + + def __init__(self): + self.log_format = LOG_FORMAT + super().__init__(self.log_format, DATE_FORMAT) + + def format(self, record) -> str: + levelname = record.levelname + record.level_color = record.msg_color = LEVEL_MAP.get(levelname) + return super().format(record) + + +class NoColorFormatter(logging.Formatter): + """ + A formatter without color. + It will format the log message like this: + 2024-06-24 16:39:57 | DEBUG | test_logger_factory:test_with_config:44 - [Dubbo] debug log + """ + + def __init__(self): + color_re = re.compile(r"\033\[[0-9;]*\w|%\((msg_color|level_color)\)s") + self.log_format = color_re.sub("", LOG_FORMAT) + super().__init__(self.log_format, DATE_FORMAT) diff --git a/dubbo/logger/internal/logger.py b/dubbo/logger/logging/logger.py similarity index 93% rename from dubbo/logger/internal/logger.py rename to dubbo/logger/logging/logger.py index 6e84a35..0a3887a 100644 --- a/dubbo/logger/internal/logger.py +++ b/dubbo/logger/logging/logger.py @@ -17,10 +17,10 @@ import logging from typing import Dict -from dubbo.common.constants.logger import Level +from dubbo.common.constants.logger_constants import Level from dubbo.logger import Logger -# The mapping from the logging level to the internal logging level. +# The mapping from the logging level to the logging level. _level_map: Dict[Level, int] = { Level.DEBUG: logging.DEBUG, Level.INFO: logging.INFO, @@ -31,9 +31,9 @@ } -class InternalLogger(Logger): +class LoggingLogger(Logger): """ - The internal logger implementation. + The logging logger implementation. Attributes: _logger (logging.Logger): The real working logger object """ diff --git a/dubbo/logger/internal/logger_adapter.py b/dubbo/logger/logging/logger_adapter.py similarity index 76% rename from dubbo/logger/internal/logger_adapter.py rename to dubbo/logger/logging/logger_adapter.py index b4ba560..e0ce6eb 100644 --- a/dubbo/logger/internal/logger_adapter.py +++ b/dubbo/logger/logging/logger_adapter.py @@ -16,32 +16,29 @@ import logging import os +import sys from functools import cache from logging import handlers -from dubbo.common import extension -from dubbo.common.constants import logger as logger_constants -from dubbo.common.constants.logger import FileRotateType, Level +from dubbo.common.constants import common_constants +from dubbo.common.constants import logger_constants as logger_constants +from dubbo.common.constants.logger_constants import FileRotateType, Level from dubbo.common.url import URL from dubbo.logger import Logger, LoggerAdapter -from dubbo.logger.internal.logger import InternalLogger +from dubbo.logger.logging import formatter +from dubbo.logger.logging.logger import LoggingLogger -"""This module provides the internal logger implementation. -> logging module""" +"""This module provides the logging logger implementation. -> logging module""" -_default_format = "%(asctime)s | %(levelname)s | %(module)s:%(funcName)s:%(lineno)d - [Dubbo] %(message)s" - -@extension.register_logger_adapter("logging") -class InternalLoggerAdapter(LoggerAdapter): +class LoggingLoggerAdapter(LoggerAdapter): """ - Internal logger adapter.Responsible for internal logger creation, encapsulated the logging.getLogger() method + Internal logger adapter.Responsible for logging logger creation, encapsulated the logging.getLogger() method Attributes: _level(Level): logging level. - _format(str): default logging format string. 
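LoggingLoggerAdapter attaches ColorFormatter to the console handler and NoColorFormatter to the file handler. A standalone sketch of the same wiring on a plain stdlib handler, purely for illustration:

import logging

from dubbo.logger.logging import formatter

# Mirror LoggingLoggerAdapter._get_console_handler(): a StreamHandler with the
# colored formatter attached.
handler = logging.StreamHandler()
handler.setFormatter(formatter.ColorFormatter())

demo = logging.getLogger("dubbo.formatter.demo")
demo.addHandler(handler)
demo.setLevel(logging.DEBUG)
demo.debug("colored debug line")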
""" _level: Level - _format: str def __init__(self, config: URL): super().__init__(config) @@ -49,10 +46,6 @@ def __init__(self, config: URL): level_name = config.parameters.get(logger_constants.LEVEL_KEY) self._level = Level.get_level(level_name) if level_name else Level.DEBUG self._update_level() - # Set format - self._format = ( - config.parameters.get(logger_constants.FORMAT_KEY) or _default_format - ) def get_logger(self, name: str) -> Logger: """ @@ -68,18 +61,29 @@ def get_logger(self, name: str) -> Logger: parameters = self._config.parameters # Add console handler - if parameters.get(logger_constants.CONSOLE_ENABLED_KEY) == str(True): + if parameters.get( + logger_constants.CONSOLE_ENABLED_KEY, + logger_constants.DEFAULT_CONSOLE_ENABLED_VALUE, + ).lower() == common_constants.TRUE_VALUE or bool( + sys.stdout and sys.stdout.isatty() + ): logger_instance.addHandler(self._get_console_handler()) # Add file handler - if parameters.get(logger_constants.FILE_ENABLED_KEY) == str(True): + if ( + parameters.get( + logger_constants.FILE_ENABLED_KEY, + logger_constants.DEFAULT_FILE_ENABLED_VALUE, + ).lower() + == common_constants.TRUE_VALUE + ): logger_instance.addHandler(self._get_file_handler()) if not logger_instance.handlers: # It's intended to be used to avoid the "No handlers could be found for logger XXX" one-off warning. logger_instance.addHandler(logging.NullHandler()) - return InternalLogger(logger_instance) + return LoggingLogger(logger_instance) @cache def _get_console_handler(self) -> logging.StreamHandler: @@ -88,13 +92,8 @@ def _get_console_handler(self) -> logging.StreamHandler: Returns: logging.StreamHandler: The console handler. """ - parameters = self._config.parameters console_handler = logging.StreamHandler() - console_format = ( - parameters.get(logger_constants.CONSOLE_FORMAT_KEY) or self._format - ) - console_formatter = logging.Formatter(console_format) - console_handler.setFormatter(console_formatter) + console_handler.setFormatter(formatter.ColorFormatter()) return console_handler @@ -140,9 +139,7 @@ def _get_file_handler(self) -> logging.Handler: file_handler = logging.FileHandler(file_path) # Add file_handler - file_format = parameters.get(logger_constants.FILE_FORMAT_KEY) or self._format - file_formatter = logging.Formatter(file_format) - file_handler.setFormatter(file_formatter) + file_handler.setFormatter(formatter.NoColorFormatter()) return file_handler @property diff --git a/dubbo/loop/__init__.py b/dubbo/loop/__init__.py new file mode 100644 index 0000000..a7ebe86 --- /dev/null +++ b/dubbo/loop/__init__.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.loop.loop_manger import LoopManager as _LoopManager + + +def _try_use_uvloop() -> None: + """ + Use uvloop instead of the default asyncio loop. 
+ """ + import asyncio + import os + + from dubbo.logger.logger_factory import loggerFactory + + logger = loggerFactory.get_logger("try_use_uvloop") + + # Check if the operating system. + if os.name == "nt": + # Windows is not supported. + logger.warning( + "Unable to use uvloop, because it is not supported on your operating system." + ) + return + + # Try import uvloop. + try: + import uvloop + except ImportError: + # uvloop is not available. + logger.warning( + "Unable to use uvloop, because it is not installed. " + "You can install it by running `pip install uvloop`." + ) + return + + # Use uvloop instead of the default asyncio loop. + if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +# Call the function to try to use uvloop. +_try_use_uvloop() + +loopManager = _LoopManager() diff --git a/dubbo/loop/loop_manger.py b/dubbo/loop/loop_manger.py new file mode 100644 index 0000000..825f2c7 --- /dev/null +++ b/dubbo/loop/loop_manger.py @@ -0,0 +1,111 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import threading +from typing import Optional + +from dubbo.logger.logger_factory import loggerFactory + +logger = loggerFactory.get_logger(__name__) + + +def start_loop(loop): + """ + Start the loop. + Args: + loop: The loop to start. + """ + asyncio.set_event_loop(loop) + loop.run_forever() + + +class LoopManager: + """ + Loop manager. + It used to manage the global event loop and therefore designed as a singleton pattern. + Attributes: + _instance: The instance of the loop manager. + _ins_lock: The lock to protect the instance. + _client_initialized: Whether the client is initialized. + _client_destroyed: Whether the client is destroyed. + _client_loop_info: The client info. (thread, loop) + _cli_lock: The lock to protect the client info. + """ + + _instance = None + _ins_lock = threading.Lock() + + # About client + _client_initialized = False + _client_destroyed = False + _client_loop_info = None + _cli_lock = threading.Lock() + + def __new__(cls): + if cls._instance is None: + with cls._ins_lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def _init_client_loop(self): + """ + Initialize the client loop. + return: The client info. (thread, loop) + """ + new_loop = asyncio.new_event_loop() + # Start the loop in a new thread + thread = threading.Thread( + target=start_loop, args=(new_loop,), name="dubbo-client-loop", daemon=True + ) + thread.start() + self._client_loop_info = (thread, new_loop) + self._client_initialized = True + logger.info("The client loop is initialized.") + return self._client_loop_info + + def get_client_loop(self) -> Optional[asyncio.AbstractEventLoop]: + """ + Get the client loop. 
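get_client_loop() lazily starts a daemon thread running a dedicated asyncio loop, and synchronous callers hand coroutines to it with run_coroutine_threadsafe. A small sketch of that pattern (the coroutine is illustrative):

import asyncio

from dubbo.loop import loopManager

async def ping() -> str:
    await asyncio.sleep(0.1)
    return "pong"

# Schedule the coroutine on the background client loop and block on the result,
# the same pattern AioClient.start() uses later in this patch.
loop = loopManager.get_client_loop()
future = asyncio.run_coroutine_threadsafe(ping(), loop)
print(future.result())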
Lazy initialization. + return: If the client is destroyed, return None. Otherwise, return the client loop. + """ + if self._client_destroyed: + logger.error("The client is destroyed.") + return None + + if not self._client_initialized: + with self._cli_lock: + if not self._client_initialized: + self._init_client_loop() + return self._client_loop_info[1] + + def destroy_client_loop(self) -> None: + """ + Destroy the client. This method can only be called once. + """ + if self._client_destroyed: + logger.info("The client is already destroyed.") + return + + with self._cli_lock: + if not self._client_destroyed: + client_loop_info = self._client_loop_info + # Stop the loop + client_loop_info[1].stop() + # Wait for the loop to stop + client_loop_info[0].join() + self._client_destroyed = True + logger.info("The client is destroyed.") diff --git a/dubbo/protocol/__init__.py b/dubbo/protocol/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/protocol/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/protocol/invocation.py b/dubbo/protocol/invocation.py new file mode 100644 index 0000000..4e4a7f6 --- /dev/null +++ b/dubbo/protocol/invocation.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + + +class Invocation: + + def get_service_name(self) -> str: + """ + Get the service name. + """ + raise NotImplementedError("get_service_name() is not implemented.") + + def get_method_name(self) -> str: + """ + Get the method name. + """ + raise NotImplementedError("get_method_name() is not implemented.") + + def get_argument(self) -> Any: + """ + Get the method argument. + """ + raise NotImplementedError("get_args() is not implemented.") + + +class RpcInvocation(Invocation): + """ + The RpcInvocation class is an implementation of the Invocation interface. + Args: + service_name (str): The name of the service. + method_name (str): The name of the method. + argument (Any): The method argument. 
+ req_serializer (Any): The request serializer. + res_serializer (Any): The response serializer. + """ + + def __init__( + self, + service_name: str, + method_name: str, + argument: Any, + req_serializer=None, + res_serializer=None, + ): + self._service_name = service_name + self._method_name = method_name + self._argument = argument + self._req_serializer = req_serializer + self._res_serializer = res_serializer + + def get_service_name(self): + return self._service_name + + def get_method_name(self): + return self._method_name + + def get_argument(self): + return self._argument + + def get_req_serializer(self): + return self._req_serializer + + def get_res_serializer(self): + return self._res_serializer diff --git a/dubbo/protocol/invoker.py b/dubbo/protocol/invoker.py new file mode 100644 index 0000000..8d5b64d --- /dev/null +++ b/dubbo/protocol/invoker.py @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dubbo.common.node import Node +from dubbo.protocol.invocation import Invocation +from dubbo.protocol.result import Result + + +class Invoker(Node): + + def get_interface(self): + """ + Get service interface. + """ + raise NotImplementedError("get_interface() is not implemented.") + + def invoke(self, invocation: Invocation) -> Result: + """ + Invoke the service. + Returns: + Result: The result of the invocation. + """ + raise NotImplementedError("invoke() is not implemented.") diff --git a/dubbo/protocol/protocol.py b/dubbo/protocol/protocol.py new file mode 100644 index 0000000..5ae08a0 --- /dev/null +++ b/dubbo/protocol/protocol.py @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dubbo.common.url import URL +from dubbo.protocol.invoker import Invoker + + +class Protocol: + + def refer(self, url: URL) -> Invoker: + """ + Refer a remote service. + Args: + url (URL): The URL of the remote service. + Returns: + Invoker: The invoker of the remote service. 
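An Invoker receives its call details through an Invocation, and RpcInvocation above is the concrete carrier for the service name, method name, payload and optional (de)serializers. A sketch of constructing one, with all values as placeholders:

from dubbo.protocol.invocation import RpcInvocation

invocation = RpcInvocation(
    service_name="org.apache.dubbo.samples.GreeterService",  # placeholder
    method_name="sayHello",                                  # placeholder
    argument=b"hello",  # already-serialized bytes, so no request serializer needed
    req_serializer=None,
    res_serializer=None,
)
assert invocation.get_method_name() == "sayHello"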
+ """ + raise NotImplementedError("refer() is not implemented.") diff --git a/dubbo/protocol/result.py b/dubbo/protocol/result.py new file mode 100644 index 0000000..06b54e1 --- /dev/null +++ b/dubbo/protocol/result.py @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Result: + pass diff --git a/dubbo/protocol/triple/__init__.py b/dubbo/protocol/triple/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/protocol/triple/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/protocol/triple/tri_decoder.py b/dubbo/protocol/triple/tri_decoder.py new file mode 100644 index 0000000..3defcbd --- /dev/null +++ b/dubbo/protocol/triple/tri_decoder.py @@ -0,0 +1,152 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum + +from dubbo.compressor.compressor import DeCompressor + + +class GrpcDecodeState(enum.Enum): + """ + gRPC Decode State + """ + + HEADER = 0 + PAYLOAD = 1 + + +class TriDecoder: + """ + This class is responsible for decoding the gRPC message format, which is composed of a header and payload. 
+ gRPC Message Format Diagram + + +----------------------+-------------------------+------------------+ + | HTTP Header | gRPC Header | Business Data | + +----------------------+-------------------------+------------------+ + | (variable length) | type (1 byte) | data (variable) | + | | compressed-flag (1 byte)| | + | | message length (4 byte) | | + +----------------------+-------------------------+------------------+ + + Args: + decompressor (DeCompressor): The decompressor to use for decompressing the payload. + listener (TriDecoder.Listener): The listener to deliver the decoded payload to. + + """ + + HEADER_LENGTH: int = 5 + COMPRESSED_FLAG_MASK: int = 1 + RESERVED_MASK: int = 0xFE + + def __init__(self, decompressor: DeCompressor, listener: "TriDecoder.Listener"): + self.accumulate = bytearray() + self._decompressor = decompressor + self._listener = listener + self.state = GrpcDecodeState.HEADER + self.required_length = self.HEADER_LENGTH + self.compressed = False + self.in_delivery = False + self.closing = False + self.closed = False + + def deframe(self, data: bytes): + """ + Process the incoming bytes, deframing the gRPC message and delivering the payload to the listener. + """ + self.accumulate.extend(data) + self._deliver() + + def close(self): + """ + Close the decoder and listener. + """ + self.closing = True + self._deliver() + + def _deliver(self): + """ + Deliver the accumulated bytes to the listener, processing the header and payload as necessary. + """ + if self.in_delivery: + return + + self.in_delivery = True + try: + while self._has_enough_bytes(): + if self.state == GrpcDecodeState.HEADER: + self._process_header() + elif self.state == GrpcDecodeState.PAYLOAD: + self._process_payload() + if self.closing: + if not self.closed: + self.closed = True + self.accumulate = None + self._listener.close() + finally: + self.in_delivery = False + + def _has_enough_bytes(self): + """ + Check if the accumulated bytes are enough to process the header or payload + """ + return len(self.accumulate) >= self.required_length + + def _process_header(self): + """ + Processes the GRPC compression header which is composed of the compression flag and the outer frame length. + """ + header_bytes = self.accumulate[: self.required_length] + self.accumulate = self.accumulate[self.required_length :] + + type_byte = header_bytes[0] + + if type_byte & self.RESERVED_MASK: + raise ValueError("gRPC frame header malformed: reserved bits not zero") + + self.compressed = bool(type_byte & self.COMPRESSED_FLAG_MASK) + self.required_length = int.from_bytes(header_bytes[1:], byteorder="big") + + # Continue to process the payload + self.state = GrpcDecodeState.PAYLOAD + + def _process_payload(self): + """ + Processes the GRPC message body, which depending on frame header flags may be compressed. + """ + payload_bytes = self.accumulate[: self.required_length] + self.accumulate = self.accumulate[self.required_length :] + + if self.compressed: + # Decompress the payload + payload_bytes = self._decompressor.decompress(payload_bytes) + + self._listener.on_message(payload_bytes) + + # Done with this frame, begin processing the next header. + self.required_length = self.HEADER_LENGTH + self.state = GrpcDecodeState.HEADER + + class Listener: + def on_message(self, message: bytes): + """ + Called when a message is received. + """ + raise NotImplementedError("Listener.on_message() not implemented") + + def close(self): + """ + Called when the listener is closed. 
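The framing described above is five header bytes (one flag byte plus a four-byte big-endian length) followed by the payload. A self-contained sketch that frames an uncompressed message and feeds it back through TriDecoder; the listener class is illustrative, and passing None for the decompressor is safe here only because the compressed flag stays 0:

from dubbo.protocol.triple.tri_decoder import TriDecoder

class PrintListener(TriDecoder.Listener):
    def on_message(self, message: bytes):
        print("payload:", message)

    def close(self):
        print("stream closed")

payload = b"hello"
frame = bytes([0]) + len(payload).to_bytes(4, "big") + payload  # flag + length + body
decoder = TriDecoder(decompressor=None, listener=PrintListener())
decoder.deframe(frame)  # -> payload: b'hello'
decoder.close()         # -> stream closed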
+ """ + raise NotImplementedError("Listener.close() not implemented") diff --git a/dubbo/protocol/triple/tri_invoker.py b/dubbo/protocol/triple/tri_invoker.py new file mode 100644 index 0000000..d2730a8 --- /dev/null +++ b/dubbo/protocol/triple/tri_invoker.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dubbo.common.url import URL +from dubbo.protocol.invocation import Invocation +from dubbo.protocol.invoker import Invoker +from dubbo.protocol.result import Result + + +class TripleInvoker(Invoker): + + def __init__(self, url: URL): + self.url = url + + def invoke(self, invocation: Invocation) -> Result: + pass + + def get_url(self) -> URL: + return self.url + + def is_available(self) -> bool: + pass + + def destroy(self) -> None: + pass diff --git a/dubbo/protocol/triple/tri_stream.py b/dubbo/protocol/triple/tri_stream.py new file mode 100644 index 0000000..aeb5ada --- /dev/null +++ b/dubbo/protocol/triple/tri_stream.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Tuple + + +class Stream: + """ + Stream is a bi-directional channel that manipulates the data flow between peers. + Inbound data from remote peer is acquired by Stream.Listener. + Outbound data to remote peer is sent directly by Stream. + """ + + def send_headers(self, headers: List[Tuple[str, str]]) -> None: + """ + Send the headers frame + Args: + headers: The headers to send. + """ + raise NotImplementedError("send_headers() is not implemented") + + def send_data(self, stream_id: int, data: bytes, end_stream: bool = False) -> None: + """ + Send the data frame + Args: + stream_id: The stream ID the data is associated with. + data: The data to send. + end_stream: Whether to end the stream. 
+ """ + raise NotImplementedError("send_data() is not implemented") + + class Listener: + """ + Listener is the interface to receive the data flow from the remote peer + """ + + def receive_headers( + self, stream_id: int, headers: List[Tuple[str, str]] + ) -> None: + """ + Called when the header frame is received + Args: + stream_id: The stream ID the headers are associated with. + headers: The headers received. + """ + raise NotImplementedError("receive_headers() is not implemented") + + def receive_data(self, stream_id: int, data: bytes) -> None: + """ + Called when the data frame is received + Args: + stream_id: The stream ID the data is associated with. + data: The data received. + """ + raise NotImplementedError("receive_data() is not implemented") + + def receive_trailers( + self, stream_id: int, headers: List[Tuple[str, str]] + ) -> None: + """ + Called when the trailers frame is received + Args: + stream_id: The stream ID the trailers are associated with. + headers: The trailers received. + """ + raise NotImplementedError("receive_trailers() is not implemented") + + def receive_end(self, stream_id: int) -> None: + """ + Called when the stream is ended + Args: + stream_id: The stream ID that was ended. + """ + raise NotImplementedError("receive_end() is not implemented") diff --git a/dubbo/protocol/triple/triple_protocol.py b/dubbo/protocol/triple/triple_protocol.py new file mode 100644 index 0000000..445ffef --- /dev/null +++ b/dubbo/protocol/triple/triple_protocol.py @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dubbo.common.url import URL +from dubbo.logger.logger_factory import loggerFactory +from dubbo.protocol.invoker import Invoker +from dubbo.protocol.protocol import Protocol + +logger = loggerFactory.get_logger(__name__) + + +class TripleProtocol(Protocol): + + def refer(self, url: URL) -> Invoker: + + pass diff --git a/dubbo/remoting/__init__.py b/dubbo/remoting/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/remoting/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/remoting/aio/__init__.py b/dubbo/remoting/aio/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/remoting/aio/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/remoting/aio/aio_transporter.py b/dubbo/remoting/aio/aio_transporter.py new file mode 100644 index 0000000..882223f --- /dev/null +++ b/dubbo/remoting/aio/aio_transporter.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio + +from h2.config import H2Configuration + +from dubbo.common.url import URL +from dubbo.logger.logger_factory import loggerFactory +from dubbo.loop import loopManager +from dubbo.remoting.aio.http2_protocol import Http2Protocol +from dubbo.remoting.transporter import (RemotingClient, RemotingServer, + Transporter) + +logger = loggerFactory.get_logger(__name__) + + +class AioTransporter(Transporter): + """ + Asyncio transporter. + """ + + def bind(self, url: URL) -> RemotingServer: + return AioServer() + + def connect(self, url: URL) -> RemotingClient: + return AioClient(url) + + +class AioClient(RemotingClient): + """ + Asyncio client. 
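AioTransporter builds its client side on the shared background loop: connect() returns an AioClient, and start() blocks until the HTTP/2 connection attempt has resolved. A usage sketch, assuming a server is actually listening at the placeholder address:

from dubbo.common.url import URL
from dubbo.remoting.aio.aio_transporter import AioTransporter

client = AioTransporter().connect(URL.value_of("tri://127.0.0.1:50051"))
client.start()
print("connected:", client.is_available())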
+ """ + def __init__(self, url: URL): + self.url = url + self._protocol = None + self._transport = None + self._loop = loopManager.get_client_loop() + + self._closed = False + + async def _create_connect(self): + transport, protocol = await self._loop.create_connection( + lambda: Http2Protocol( + H2Configuration(client_side=True, header_encoding="utf-8") + ), + self.url.host, + self.url.port if self.url.port else None, + ) + return transport, protocol + + def start(self): + future = asyncio.run_coroutine_threadsafe(self._create_connect(), self._loop) + try: + self._transport, self._protocol = future.result() + except Exception: + logger.exception("Failed to create connection.") + self._transport = None + self._protocol = None + + def is_available(self) -> bool: + if self._closed: + return False + return self._transport and not self._transport.is_closing() + + async def send(self, data: bytes): + self._protocol.send_data(data) + + async def close(self): + self._closed = True + self._transport.close() + await self._transport.wait_closed() + + +class AioServer(RemotingServer): + """ + Asyncio server. + """ + pass diff --git a/dubbo/remoting/aio/http2_protocol.py b/dubbo/remoting/aio/http2_protocol.py new file mode 100644 index 0000000..76dfa99 --- /dev/null +++ b/dubbo/remoting/aio/http2_protocol.py @@ -0,0 +1,165 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from typing import List, Optional, Tuple + +from h2.config import H2Configuration +from h2.connection import H2Connection +from h2.events import (DataReceived, RemoteSettingsChanged, RequestReceived, + ResponseReceived, StreamEnded, TrailersReceived, + WindowUpdated) +from h2.exceptions import ProtocolError, StreamClosedError +from h2.settings import SettingCodes + +from dubbo.logger.logger_factory import loggerFactory + +logger = loggerFactory.get_logger(__name__) + + +class Http2Protocol(asyncio.Protocol): + + def __init__(self, h2_config: H2Configuration): + h2_config.logger = logger + self.conn = H2Connection(config=h2_config) + self.transport = None + self.flow_control_futures = {} + + def connection_made(self, transport: asyncio.Transport) -> None: + self.transport = transport + self.conn.initiate_connection() + + def connection_lost(self, exc: Exception) -> None: + if exc: + logger.error(f"Connection lost: {exc}") + else: + logger.info("Connection closed cleanly.") + self.transport.close() + + async def send_headers( + self, + headers: List[Tuple[str, str]], + stream_id: Optional[int] = None, + end_stream=False, + ) -> int: + """ + Send headers to the server or client. + Args: + headers: A list of header tuples. + stream_id: The stream ID to send the headers on. If None, a new stream will be created. + end_stream: Whether to close the stream after sending the headers. 
+ Returns: + The stream ID the headers were sent on. + """ + if not stream_id: + # Get the next available stream ID. + stream_id = self.conn.get_next_available_stream_id() + self.conn.send_headers(stream_id, headers, end_stream=end_stream) + self.transport.write(self.conn.data_to_send()) + return stream_id + + async def send_data(self, stream_id: int, data: bytes, end_stream=False) -> None: + """ + Send data according to the flow control rules. + Args: + stream_id: The stream ID to send the data on. + data: The data to send. + end_stream: Whether to close the stream after sending the data. + """ + while data: + # Check the flow control window. + while self.conn.local_flow_control_window(stream_id) < 1: + try: + # Wait for flow control window to open. + await self.wait_for_flow_control(stream_id) + except asyncio.CancelledError: + return + # Determine how much data to send. + chunk_size = min( + self.conn.local_flow_control_window(stream_id), + len(data), + self.conn.max_outbound_frame_size, + ) + try: + # Send the data. + self.conn.send_data( + stream_id, + data[:chunk_size], + end_stream=(chunk_size == len(data) and end_stream), + ) + except (StreamClosedError, ProtocolError): + logger.error( + f"Stream {stream_id} closed unexpectedly, aborting data send." + ) + break + + self.transport.write(self.conn.data_to_send()) + data = data[chunk_size:] + + def data_received(self, data: bytes) -> None: + try: + # Parse the received data. + events = self.conn.receive_data(data) + + if not events: + self.transport.write(self.conn.data_to_send()) + else: + # Process the events. + for event in events: + if isinstance(event, ResponseReceived) or isinstance( + event, RequestReceived + ): + self.receive_headers(event.stream_id, event.headers) + elif isinstance(event, DataReceived): + self.receive_data(event.stream_id, event.data) + elif isinstance(event, TrailersReceived): + self.receive_trailers(event.stream_id, event.headers) + elif isinstance(event, StreamEnded): + self.receive_end(event.stream_id) + elif isinstance(event, WindowUpdated): + self.window_updated(event.stream_id, event.delta) + elif isinstance(event, RemoteSettingsChanged): + if SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings: + self.window_updated(None, 0) + + data = self.conn.data_to_send() + if data: + self.transport.write(data) + + except ProtocolError: + logger.exception("Parse HTTP2 frame error") + self.transport.write(self.conn.data_to_send()) + self.transport.close() + + async def wait_for_flow_control(self, stream_id) -> None: + """ + Waits for a Future that fires when the flow control window is opened. + """ + f = asyncio.Future() + self.flow_control_futures[stream_id] = f + await f + + def window_updated(self, stream_id, delta) -> None: + """ + A window update frame was received. Unblock some number of flow control Futures. + """ + if stream_id and stream_id in self.flow_control_futures: + future = self.flow_control_futures.pop(stream_id) + future.set_result(delta) + else: + # If it does not match, remove all flow control. + for f in self.flow_control_futures.values(): + f.set_result(delta) + self.flow_control_futures.clear() diff --git a/dubbo/remoting/transporter.py b/dubbo/remoting/transporter.py new file mode 100644 index 0000000..48c9f43 --- /dev/null +++ b/dubbo/remoting/transporter.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dubbo.common.url import URL + + +class RemotingServer: + + pass + + +class RemotingClient: + + pass + + +class Transporter: + def bind(self, url: URL) -> RemotingServer: + """ + Bind a server. + """ + pass + + def connect(self, url: URL) -> RemotingClient: + """ + Connect to a server. + """ + pass diff --git a/dubbo/serialization/__init__.py b/dubbo/serialization/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/dubbo/serialization/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dubbo/serialization/serialization.py b/dubbo/serialization/serialization.py new file mode 100644 index 0000000..937267b --- /dev/null +++ b/dubbo/serialization/serialization.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
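wait_for_flow_control() and window_updated() in Http2Protocol above cooperate through plain asyncio Futures: the sender parks on a Future, and the WINDOW_UPDATE handler resolves it with the window delta. A minimal standalone sketch of that handshake, with the peer's update simulated by call_later:

import asyncio

async def demo() -> None:
    loop = asyncio.get_running_loop()
    window_open = loop.create_future()

    def window_updated(delta: int) -> None:
        # Mirrors Http2Protocol.window_updated(): resolve the waiting Future.
        window_open.set_result(delta)

    # Pretend the peer sends a WINDOW_UPDATE frame 0.1s later.
    loop.call_later(0.1, window_updated, 65535)
    delta = await window_open
    print("window grew by", delta)

asyncio.run(demo())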
+from typing import Any + +from dubbo.common.constants import common_constants +from dubbo.common.url import URL +from dubbo.logger import logger_factory + +logger = logger_factory.get_logger(__name__) + + +def serialize(method: str, url: URL, *args, **kwargs) -> bytes: + """ + Serialize the given data + Args: + method(str): The method to serialize + url(URL): URL + *args: Variable length argument list + **kwargs: Arbitrary keyword arguments + Returns: + bytes: The serialized data + Exception: If the serialization fails + """ + # get the serializer + method_dict = url.get_attribute(method) or {} + serializer = method_dict.get(common_constants.SERIALIZATION) + # serialize the data + if serializer: + try: + return serializer(*args, **kwargs) + except Exception as e: + logger.exception( + "Serialization send error, please check the incoming serialization function" + ) + raise e + else: + # check if the data is bytes -> args[0] + if isinstance(args[0], bytes): + return args[0] + else: + err_msg = "The args[0] is not bytes, you should pass parameters of type bytes, or set the serialization function" + logger.error(err_msg) + raise ValueError(err_msg) + + +def deserialize(method: str, url: URL, data: bytes) -> Any: + """ + Deserialize the given data + Args: + method(str): The method to deserialize + url(URL): URL + data(bytes): The data to deserialize + Returns: + Any: The deserialized data + Exception: If the deserialization fails + """ + # get the deserializer + method_dict = url.get_attribute(method) or {} + deserializer = method_dict.get(common_constants.DESERIALIZATION) + # deserialize the data + if not deserializer: + return data + else: + try: + return deserializer(data) + except Exception as e: + logger.exception( + "Deserialization send error, please check the incoming deserialization function" + ) + raise e diff --git a/requirements.txt b/requirements.txt index e69de29..b782d68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1 @@ +h2~=4.1.0 \ No newline at end of file diff --git a/tests/logger/test_logger_factory.py b/tests/logger/test_logger_factory.py index c33204a..fa3016a 100644 --- a/tests/logger/test_logger_factory.py +++ b/tests/logger/test_logger_factory.py @@ -15,11 +15,11 @@ # limitations under the License. 
import unittest -from dubbo.common.constants import logger as logger_constants -from dubbo.common.constants.logger import Level +from dubbo.common.constants import logger_constants as logger_constants +from dubbo.common.constants.logger_constants import Level from dubbo.config import LoggerConfig -from dubbo.logger import loggerFactory -from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter +from dubbo.logger.logger_factory import loggerFactory +from dubbo.logger.logging.logger_adapter import LoggingLoggerAdapter class TestLoggerFactory(unittest.TestCase): @@ -31,18 +31,19 @@ def test_without_config(self): def test_with_config(self): # Test the case where config is used - config = LoggerConfig("logging") + config = LoggerConfig.default_config() config.init() logger = loggerFactory.get_logger("test_factory") logger.info("info log -> with_config ") url = config.get_url() url.add_parameter(logger_constants.FILE_ENABLED_KEY, True) - loggerFactory.set_logger_adapter(InternalLoggerAdapter(url)) + loggerFactory.set_logger_adapter(LoggingLoggerAdapter(url)) loggerFactory.set_level(Level.DEBUG) + logger = loggerFactory.get_logger("test_factory") logger.debug("debug log -> with_config -> open file") url.add_parameter(logger_constants.CONSOLE_ENABLED_KEY, False) - loggerFactory.set_logger_adapter(InternalLoggerAdapter(url)) + loggerFactory.set_logger_adapter(LoggingLoggerAdapter(url)) loggerFactory.set_level(Level.DEBUG) logger.debug("debug log -> with_config -> lose console") diff --git a/tests/logger/test_internal_logger.py b/tests/logger/test_logging_logger.py similarity index 87% rename from tests/logger/test_internal_logger.py rename to tests/logger/test_logging_logger.py index 91fbbb5..c95a9ab 100644 --- a/tests/logger/test_internal_logger.py +++ b/tests/logger/test_logging_logger.py @@ -15,15 +15,17 @@ # limitations under the License. import unittest -from dubbo.common.constants.logger import Level +from dubbo.common.constants.logger_constants import Level from dubbo.config import LoggerConfig -from dubbo.logger.internal.logger_adapter import InternalLoggerAdapter +from dubbo.logger.logging.logger_adapter import LoggingLoggerAdapter class TestInternalLogger(unittest.TestCase): def test_log(self): - logger_adapter = InternalLoggerAdapter(config=LoggerConfig("logging").get_url()) + logger_adapter = LoggingLoggerAdapter( + config=LoggerConfig.default_config().get_url() + ) logger = logger_adapter.get_logger("test") logger.log(Level.INFO, "test log") logger.debug("test debug") diff --git a/tests/loop/__init__.py b/tests/loop/__init__.py new file mode 100644 index 0000000..bcba37a --- /dev/null +++ b/tests/loop/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/loop/test_loop_manger.py b/tests/loop/test_loop_manger.py new file mode 100644 index 0000000..835b92c --- /dev/null +++ b/tests/loop/test_loop_manger.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import unittest + +from dubbo.loop.loop_manger import LoopManager + + +async def _loop_task(): + while True: + print("loop task") + await asyncio.sleep(1) + + +class TestLoopManager(unittest.TestCase): + + def test_use_client(self): + loop_manager = LoopManager() + loop = loop_manager.get_client_loop() + asyncio.run_coroutine_threadsafe(_loop_task(), loop) + print("loop task started, waiting for 3 seconds...") + asyncio.run(asyncio.sleep(3)) + loop_manager.destroy_client_loop() + print("loop task stopped.") diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..b703b83 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,81 @@ +import asyncio +import concurrent.futures + + +# 定义异步 TCP 客户端任务 +class EchoClientProtocol(asyncio.Protocol): + def __init__(self, message, loop, on_con_lost): + self.message = message + self.loop = loop + self.on_con_lost = on_con_lost + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print("Data sent:", self.message) + + def data_received(self, data): + print("Data received:", data.decode()) + self.transport.close() + + def connection_lost(self, exc): + print("The server closed the connection") + self.on_con_lost.set_result(True) + + +async def tcp_client(loop, message, host, port): + on_con_lost = loop.create_future() + transport, protocol = await loop.create_connection( + lambda: EchoClientProtocol(message, loop, on_con_lost), host, port + ) + try: + await on_con_lost + finally: + transport.close() + + +def start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +def main(): + host = "127.0.0.1" + port = 8888 + + # 使用线程池管理事件循环线程 + with concurrent.futures.ThreadPoolExecutor() as executor: + new_loop = asyncio.new_event_loop() + executor.submit(start_loop, new_loop) + + # 创建并提交 TCP 客户端任务到线程池中的事件循环 + future1 = asyncio.run_coroutine_threadsafe( + tcp_client(new_loop, "Message for server 1", host, port), new_loop + ) + future2 = asyncio.run_coroutine_threadsafe( + tcp_client(new_loop, "Message for server 2", host, port), new_loop + ) + future3 = asyncio.run_coroutine_threadsafe( + tcp_client(new_loop, "Message for server 3", host, port), new_loop + ) + + # 使用返回的 Future 对象来监视和管理协程任务 + print("Waiting for tasks to complete...") + for future in [future1, future2, future3]: + try: + result = future.result() # 获取协程的结果(阻塞直到结果可用) + print(f"Task completed with result: {result}") + except Exception as e: + print(f"Task raised an exception: {e}") + + # 等待一段时间以观察任务执行 + import 
time + + time.sleep(10) # 根据需要调整等待时间 + + print("结束事件循环") + new_loop.call_soon_threadsafe(new_loop.stop) # 优雅停止事件循环 + + +if __name__ == "__main__": + main() diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..df1de7e --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + + +class EchoServerProtocol(asyncio.Protocol): + def connection_made(self, transport): + self.transport = transport + print("Connection from", transport.get_extra_info("peername")) + + def data_received(self, data): + message = data.decode() + print("Data received:", message) + self.transport.write(data) # Echo the received data back + + def connection_lost(self, exc): + print("Client disconnected") + + +async def run_server(): + loop = asyncio.get_running_loop() + server = await loop.create_server(lambda: EchoServerProtocol(), "127.0.0.1", 8888) + async with server: + await server.serve_forever() + + +if __name__ == "__main__": + asyncio.run(run_server()) From 2fb6d89ef9d879e3a80b0c6f381e12e0d9f3cea6 Mon Sep 17 00:00:00 2001 From: zaki Date: Sat, 29 Jun 2024 13:42:32 +0800 Subject: [PATCH 23/32] fix: fix ci --- tests/common/extension/__init__.py | 15 -------- .../common/extension/test_logger_extension.py | 36 ------------------- 2 files changed, 51 deletions(-) delete mode 100644 tests/common/extension/__init__.py delete mode 100644 tests/common/extension/test_logger_extension.py diff --git a/tests/common/extension/__init__.py b/tests/common/extension/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/tests/common/extension/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
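The scratch asyncio client/server scripts and the LoopManager test above share one pattern: an event loop runs in a worker thread while synchronous code submits coroutines to it with run_coroutine_threadsafe and later stops it with call_soon_threadsafe(loop.stop). A minimal stdlib-only sketch of that pattern follows; the names say_after and run_in_background are illustrative and not part of dubbo:

# Run an asyncio event loop in a worker thread and drive it from synchronous code.
import asyncio
import threading


async def say_after(delay: float, message: str) -> str:
    await asyncio.sleep(delay)
    return message


def run_in_background() -> asyncio.AbstractEventLoop:
    loop = asyncio.new_event_loop()
    # run_forever() must execute inside the thread that owns the loop.
    threading.Thread(target=loop.run_forever, daemon=True).start()
    return loop


if __name__ == "__main__":
    background_loop = run_in_background()
    # Submit a coroutine from the main (synchronous) thread.
    future = asyncio.run_coroutine_threadsafe(say_after(0.1, "done"), background_loop)
    print(future.result())  # Blocks until the coroutine finishes.
    # Stop the loop from outside its thread, as test_client.py does.
    background_loop.call_soon_threadsafe(background_loop.stop)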
diff --git a/tests/common/extension/test_logger_extension.py b/tests/common/extension/test_logger_extension.py deleted file mode 100644 index 350be07..0000000 --- a/tests/common/extension/test_logger_extension.py +++ /dev/null @@ -1,36 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest - -from dubbo.common import extension -from dubbo.config import LoggerConfig - - -class TestLoggerExtension(unittest.TestCase): - - def test_logger_extension(self): - - # Test the get_logger_adapter method. - logger_adapter = extension.get_logger_adapter( - "logging", LoggerConfig("logging").get_url() - ) - - # Test logger_adapter methods. - logger = logger_adapter.get_logger("test") - logger.debug("test debug") - logger.info("test info") - logger.warning("test warning") - logger.error("test error") From c4f8d52ab10743b16504cd60a7f088dbb97804d7 Mon Sep 17 00:00:00 2001 From: zaki Date: Sat, 29 Jun 2024 13:47:29 +0800 Subject: [PATCH 24/32] fix: Delete some invalid files --- tests/test_client.py | 81 -------------------------------------------- tests/test_server.py | 43 ----------------------- 2 files changed, 124 deletions(-) delete mode 100644 tests/test_client.py delete mode 100644 tests/test_server.py diff --git a/tests/test_client.py b/tests/test_client.py deleted file mode 100644 index b703b83..0000000 --- a/tests/test_client.py +++ /dev/null @@ -1,81 +0,0 @@ -import asyncio -import concurrent.futures - - -# 定义异步 TCP 客户端任务 -class EchoClientProtocol(asyncio.Protocol): - def __init__(self, message, loop, on_con_lost): - self.message = message - self.loop = loop - self.on_con_lost = on_con_lost - - def connection_made(self, transport): - self.transport = transport - self.transport.write(self.message.encode()) - print("Data sent:", self.message) - - def data_received(self, data): - print("Data received:", data.decode()) - self.transport.close() - - def connection_lost(self, exc): - print("The server closed the connection") - self.on_con_lost.set_result(True) - - -async def tcp_client(loop, message, host, port): - on_con_lost = loop.create_future() - transport, protocol = await loop.create_connection( - lambda: EchoClientProtocol(message, loop, on_con_lost), host, port - ) - try: - await on_con_lost - finally: - transport.close() - - -def start_loop(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - -def main(): - host = "127.0.0.1" - port = 8888 - - # 使用线程池管理事件循环线程 - with concurrent.futures.ThreadPoolExecutor() as executor: - new_loop = asyncio.new_event_loop() - executor.submit(start_loop, new_loop) - - # 创建并提交 TCP 客户端任务到线程池中的事件循环 - future1 = asyncio.run_coroutine_threadsafe( - tcp_client(new_loop, "Message for server 1", host, port), new_loop - ) - future2 = asyncio.run_coroutine_threadsafe( - tcp_client(new_loop, "Message for server 2", 
host, port), new_loop - ) - future3 = asyncio.run_coroutine_threadsafe( - tcp_client(new_loop, "Message for server 3", host, port), new_loop - ) - - # 使用返回的 Future 对象来监视和管理协程任务 - print("Waiting for tasks to complete...") - for future in [future1, future2, future3]: - try: - result = future.result() # 获取协程的结果(阻塞直到结果可用) - print(f"Task completed with result: {result}") - except Exception as e: - print(f"Task raised an exception: {e}") - - # 等待一段时间以观察任务执行 - import time - - time.sleep(10) # 根据需要调整等待时间 - - print("结束事件循环") - new_loop.call_soon_threadsafe(new_loop.stop) # 优雅停止事件循环 - - -if __name__ == "__main__": - main() diff --git a/tests/test_server.py b/tests/test_server.py deleted file mode 100644 index df1de7e..0000000 --- a/tests/test_server.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import asyncio - - -class EchoServerProtocol(asyncio.Protocol): - def connection_made(self, transport): - self.transport = transport - print("Connection from", transport.get_extra_info("peername")) - - def data_received(self, data): - message = data.decode() - print("Data received:", message) - self.transport.write(data) # Echo the received data back - - def connection_lost(self, exc): - print("Client disconnected") - - -async def run_server(): - loop = asyncio.get_running_loop() - server = await loop.create_server(lambda: EchoServerProtocol(), "127.0.0.1", 8888) - async with server: - await server.serve_forever() - - -if __name__ == "__main__": - asyncio.run(run_server()) From 952541d7f15d3d4e691845144bbda9e4d976b5dc Mon Sep 17 00:00:00 2001 From: zaki Date: Mon, 1 Jul 2024 19:49:45 +0800 Subject: [PATCH 25/32] feat: Complete the network transmission part --- dubbo/_dubbo.py | 3 +- dubbo/callable/rpc_callable.py | 9 +- dubbo/client/client.py | 6 +- dubbo/extension/__init__.py | 3 +- .../triple/{tri_stream.py => stream.py} | 77 +++- dubbo/remoting/aio/aio_stream.py | 208 ++++++++++ dubbo/remoting/aio/aio_transporter.py | 52 +-- .../__init__.py => remoting/aio/constants.py} | 3 + dubbo/remoting/aio/http2_protocol.py | 386 +++++++++++++----- dubbo/{serialization => }/serialization.py | 4 +- requirements.txt | 3 +- 11 files changed, 560 insertions(+), 194 deletions(-) rename dubbo/protocol/triple/{tri_stream.py => stream.py} (58%) create mode 100644 dubbo/remoting/aio/aio_stream.py rename dubbo/{serialization/__init__.py => remoting/aio/constants.py} (91%) rename dubbo/{serialization => }/serialization.py (96%) diff --git a/dubbo/_dubbo.py b/dubbo/_dubbo.py index 05a096f..fece509 100644 --- a/dubbo/_dubbo.py +++ b/dubbo/_dubbo.py @@ -16,8 +16,7 @@ import threading from typing import Dict, List -from dubbo.config import (ApplicationConfig, ConsumerConfig, LoggerConfig, - ProtocolConfig) +from dubbo.config import ApplicationConfig, 
ConsumerConfig, LoggerConfig, ProtocolConfig from dubbo.logger.logger_factory import loggerFactory logger = loggerFactory.get_logger(__name__) diff --git a/dubbo/callable/rpc_callable.py b/dubbo/callable/rpc_callable.py index 5f6405c..9171e1f 100644 --- a/dubbo/callable/rpc_callable.py +++ b/dubbo/callable/rpc_callable.py @@ -38,7 +38,7 @@ def __init__(self, invoker: Invoker, url: URL): method_url.get_attribute(common_constants.SERIALIZATION) or None ) - def _do_call(self, argument: Any): + async def _do_call(self, argument: Any): """ Real call method. """ @@ -66,10 +66,11 @@ def _do_call(self, argument: Any): self._res_serializer, ) # Do invoke. - return self._invoker.invoke(invocation) + result = self._invoker.invoke(invocation) + return result - def __call__(self, argument: Any): - return self._do_call(argument) + async def __call__(self, argument: Any): + return await self._do_call(argument) class AsyncRpcCallable: diff --git a/dubbo/client/client.py b/dubbo/client/client.py index f66a523..f929029 100644 --- a/dubbo/client/client.py +++ b/dubbo/client/client.py @@ -18,8 +18,10 @@ from dubbo.callable.rpc_callable import AsyncRpcCallable, RpcCallable from dubbo.callable.rpc_callable_factory import DefaultRpcCallableFactory from dubbo.common.constants import common_constants -from dubbo.common.constants.type_constants import (DeserializingFunction, - SerializingFunction) +from dubbo.common.constants.type_constants import ( + DeserializingFunction, + SerializingFunction, +) from dubbo.common.url import URL from dubbo.config import ConsumerConfig, ReferenceConfig from dubbo.logger.logger_factory import loggerFactory diff --git a/dubbo/extension/__init__.py b/dubbo/extension/__init__.py index 8744a34..0da2118 100644 --- a/dubbo/extension/__init__.py +++ b/dubbo/extension/__init__.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dubbo.extension.extension_loader import \ - ExtensionLoader as _ExtensionLoader +from dubbo.extension.extension_loader import ExtensionLoader as _ExtensionLoader extensionLoader = _ExtensionLoader() diff --git a/dubbo/protocol/triple/tri_stream.py b/dubbo/protocol/triple/stream.py similarity index 58% rename from dubbo/protocol/triple/tri_stream.py rename to dubbo/protocol/triple/stream.py index aeb5ada..65264c1 100644 --- a/dubbo/protocol/triple/tri_stream.py +++ b/dubbo/protocol/triple/stream.py @@ -23,64 +23,97 @@ class Stream: Outbound data to remote peer is sent directly by Stream. """ + def __init__(self, stream_id: int): + self._stream_id = stream_id + def send_headers(self, headers: List[Tuple[str, str]]) -> None: """ - Send the headers frame + First call: head frame + Second call: trailer frame. Args: headers: The headers to send. """ raise NotImplementedError("send_headers() is not implemented") - def send_data(self, stream_id: int, data: bytes, end_stream: bool = False) -> None: + def send_data(self, data: bytes) -> None: """ Send the data frame Args: - stream_id: The stream ID the data is associated with. data: The data to send. - end_stream: Whether to end the stream. 
""" raise NotImplementedError("send_data() is not implemented") + def send_end_stream(self) -> None: + """ + Send the end stream frame -> An empty data frame will be sent (end_stream=True) + """ + raise NotImplementedError("send_completed() is not implemented") + class Listener: """ - Listener is the interface to receive the data flow from the remote peer + Listener is the interface that receives the data from the stream. """ - def receive_headers( - self, stream_id: int, headers: List[Tuple[str, str]] - ) -> None: + def on_headers(self, headers: List[Tuple[str, str]]) -> None: """ Called when the header frame is received Args: - stream_id: The stream ID the headers are associated with. headers: The headers received. """ raise NotImplementedError("receive_headers() is not implemented") - def receive_data(self, stream_id: int, data: bytes) -> None: + def on_data(self, data: bytes) -> None: """ Called when the data frame is received Args: - stream_id: The stream ID the data is associated with. data: The data received. """ raise NotImplementedError("receive_data() is not implemented") - def receive_trailers( - self, stream_id: int, headers: List[Tuple[str, str]] - ) -> None: + def on_complete(self) -> None: + """ + Complete the stream. + """ + raise NotImplementedError("complete() is not implemented") + + +class ClientStream(Stream): + """ + ClientStream is a Stream that is initiated by the client. + """ + + pass + + class Listener(Stream.Listener): + """ + Listener is the interface that receives the data from the stream. + """ + + def on_trailers(self, headers: List[Tuple[str, str]]) -> None: """ Called when the trailers frame is received Args: - stream_id: The stream ID the trailers are associated with. headers: The trailers received. """ raise NotImplementedError("receive_trailers() is not implemented") - def receive_end(self, stream_id: int) -> None: - """ - Called when the stream is ended - Args: - stream_id: The stream ID that was ended. - """ - raise NotImplementedError("receive_end() is not implemented") + +class ServerStream(Stream): + """ + ServerStream is a Stream that is initiated by the server. + """ + + def send_trailers(self, trailers: List[Tuple[str, str]]) -> None: + """ + Send the trailers frame + Args: + trailers: The trailers to send. + """ + raise NotImplementedError("send_trailers() is not implemented") + + class Listener(Stream.Listener): + """ + Listener is the interface that receives the data from the stream. + """ + + pass diff --git a/dubbo/remoting/aio/aio_stream.py b/dubbo/remoting/aio/aio_stream.py new file mode 100644 index 0000000..de708be --- /dev/null +++ b/dubbo/remoting/aio/aio_stream.py @@ -0,0 +1,208 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import asyncio +from typing import List, Optional, Tuple + +from dubbo.logger.logger_factory import loggerFactory +from dubbo.protocol.triple.stream import ClientStream, ServerStream, Stream +from dubbo.remoting.aio.constants import END_DATA_SENTINEL + +logger = loggerFactory.get_logger(__name__) + +HEADER_FRAME = "HEADER_FRAME" +DATA_FRAME = "DATA_FRAME" +TRAILER_FRAME = "TRAILER_FRAME" + + +class AioStream(Stream): + """ + The Stream object for HTTP/2 + """ + + def __init__(self, stream_id: int, loop, protocol): + super().__init__(stream_id) + # The loop to run the asynchronous function. + self._loop = loop + # The protocol to send the frame. + self._protocol = protocol + + # The flag to indicate whether the header has been sent. + self._header_emitted = False + # This is an event that send a header frame. + # It is used to ensure that the header frame is sent before the data frame. + self._send_header_event: Optional[asyncio.Event] = None + + # The queue to store the all frames to send. It is used to ensure the order of the frames. + self._write_queue = asyncio.Queue() + # This is an event that send a data frame. + # It is used to ensure that the data frame is sent before the next data frame. + self._send_data_event: Optional[asyncio.Event] = None + + # The task to send the frames. + self._send_loop_task = self._loop.create_task(self._send_loop()) + + # The flag to indicate whether the sending is completed. + # However, it does not mean that all the data has been sent successfully, + # but is only used to prevent other data from being sent. + self._send_completed = False + + # The flag to indicate whether the receiving is completed. + self._receive_completed = False + + def send_headers(self, headers: List[Tuple[str, str]]) -> None: + """ + The first call sends the head frame, the second call sends the trailer frame. + Args: + headers: The headers to send. + """ + if self._send_completed: + raise RuntimeError("The stream has finished sending data") + + if self._header_emitted: + # If the header has been sent, it means that the trailer is being sent. + self._send_completed = True + else: + self._header_emitted = True + + def _inner_send_headers(headers, end_stream): + data_type = TRAILER_FRAME if end_stream else HEADER_FRAME + self._write_queue.put_nowait((data_type, headers)) + + self._loop.call_soon_threadsafe( + _inner_send_headers, headers, self._send_completed + ) + + def send_data(self, data: bytes) -> None: + """ + Send the data frame. + Args: + data: The data to send. + """ + if self._send_completed: + raise RuntimeError("The stream has finished sending data") + elif not self._header_emitted: + raise RuntimeError("The header has not been sent") + + def _inner_send_data(data): + self._write_queue.put_nowait((DATA_FRAME, data)) + + self._loop.call_soon_threadsafe(_inner_send_data, data) + + def send_end_stream(self) -> None: + """ + Send the end stream frame -> An empty data frame will be sent (end_stream=True) + """ + + def _inner_send_end_stream(): + self._write_queue.put_nowait((DATA_FRAME, END_DATA_SENTINEL)) + + self._loop.call_soon_threadsafe(_inner_send_end_stream) + + async def _send_loop(self): + """ + Asynchronous blocking to get data from write_queue and send it. + The purpose of using write_queue is to ensure that frames are sent in the following order: + 1. HEADER_FRAME + 2. DATA_FRAME (0 or more) + 3. 
TRAILER_FRAME (optional) + The format of the queue elements is: (type, data) -> (HEADER_FRAME, [("key", "value")]) or (DATA_FRAME, b"") + """ + while True: + data_type, data = await self._write_queue.get() + + if data_type == HEADER_FRAME: + # If the data is a header frame, send it directly. + self._send_header_event = self._protocol.send_head_frame( + self._stream_id, data + ) + continue + + # Waiting for the headers to be sent + assert self._send_header_event is not None + await self._send_header_event.wait() + + if self._send_data_event: + # Waiting for the previous message to be sent + await self._send_data_event.wait() + + if data_type == DATA_FRAME and data: + self._send_data_event = self._protocol.send_data_frame( + self._stream_id, data + ) + if data == END_DATA_SENTINEL: + # If it is an END_DATA_SENTINEL, it means that the data has been sent. + break + elif data_type == TRAILER_FRAME: + # If it is a TRAILER_FRAME, then it must also be a last frame, + # so it exits the loop when it finishes sending. + self._protocol.send_head_frame(self._stream_id, data, end_stream=True) + break + + +class AioClientStream(AioStream, ClientStream): + """ + The Stream object for the HTTP/2. (client side) + """ + + def __init__(self, loop, protocol, listener: ClientStream.Listener): + super().__init__(protocol.conn.get_next_available_stream_id(), loop, protocol) + self._protocol.register_stream(self._stream_id, self) + + # receive data + self._listener = listener + + def receive_headers(self, headers: List[Tuple[str, str]]) -> None: + """ + Receive the headers. + """ + # Running synchronized functions non-blocking + self._loop.run_in_executor(None, self._listener.on_headers, headers) + + def receive_data(self, data: bytes) -> None: + """ + Receive the data. + """ + self._loop.run_in_executor(None, self._listener.on_data, data) + + def receive_trailers(self, trailers: List[Tuple[str, str]]) -> None: + """ + Receive the trailers. + """ + self._loop.run_in_executor(None, self._listener.on_trailers, trailers) + + def receive_complete(self): + self._receive_completed = True + + +class AioServerStream(AioStream, ServerStream): + """ + The Stream object for the HTTP/2. (server side) + """ + + def __init__(self, stream_id, loop, protocol): + super().__init__(stream_id, loop, protocol) + + def receive_headers(self, headers: List[Tuple[str, str]]) -> None: + pass + + def receive_data(self, data: bytes) -> None: + pass + + def receive_trailers(self, trailers: List[Tuple[str, str]]) -> None: + pass + + def receive_complete(self): + self._receive_completed = True diff --git a/dubbo/remoting/aio/aio_transporter.py b/dubbo/remoting/aio/aio_transporter.py index 882223f..d684434 100644 --- a/dubbo/remoting/aio/aio_transporter.py +++ b/dubbo/remoting/aio/aio_transporter.py @@ -13,16 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import asyncio - -from h2.config import H2Configuration from dubbo.common.url import URL from dubbo.logger.logger_factory import loggerFactory -from dubbo.loop import loopManager -from dubbo.remoting.aio.http2_protocol import Http2Protocol -from dubbo.remoting.transporter import (RemotingClient, RemotingServer, - Transporter) +from dubbo.remoting.transporter import RemotingClient, RemotingServer, Transporter logger = loggerFactory.get_logger(__name__) @@ -33,59 +27,23 @@ class AioTransporter(Transporter): """ def bind(self, url: URL) -> RemotingServer: - return AioServer() + pass def connect(self, url: URL) -> RemotingClient: - return AioClient(url) + pass class AioClient(RemotingClient): """ Asyncio client. """ - def __init__(self, url: URL): - self.url = url - self._protocol = None - self._transport = None - self._loop = loopManager.get_client_loop() - - self._closed = False - - async def _create_connect(self): - transport, protocol = await self._loop.create_connection( - lambda: Http2Protocol( - H2Configuration(client_side=True, header_encoding="utf-8") - ), - self.url.host, - self.url.port if self.url.port else None, - ) - return transport, protocol - - def start(self): - future = asyncio.run_coroutine_threadsafe(self._create_connect(), self._loop) - try: - self._transport, self._protocol = future.result() - except Exception: - logger.exception("Failed to create connection.") - self._transport = None - self._protocol = None - def is_available(self) -> bool: - if self._closed: - return False - return self._transport and not self._transport.is_closing() - - async def send(self, data: bytes): - self._protocol.send_data(data) - - async def close(self): - self._closed = True - self._transport.close() - await self._transport.wait_closed() + pass class AioServer(RemotingServer): """ Asyncio server. """ + pass diff --git a/dubbo/serialization/__init__.py b/dubbo/remoting/aio/constants.py similarity index 91% rename from dubbo/serialization/__init__.py rename to dubbo/remoting/aio/constants.py index bcba37a..cbcc52c 100644 --- a/dubbo/serialization/__init__.py +++ b/dubbo/remoting/aio/constants.py @@ -13,3 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# Used to indicate the end of the data. 
+END_DATA_SENTINEL = object() diff --git a/dubbo/remoting/aio/http2_protocol.py b/dubbo/remoting/aio/http2_protocol.py index 76dfa99..cd5e064 100644 --- a/dubbo/remoting/aio/http2_protocol.py +++ b/dubbo/remoting/aio/http2_protocol.py @@ -16,150 +16,312 @@ import asyncio from typing import List, Optional, Tuple +import h2.events from h2.config import H2Configuration from h2.connection import H2Connection -from h2.events import (DataReceived, RemoteSettingsChanged, RequestReceived, - ResponseReceived, StreamEnded, TrailersReceived, - WindowUpdated) -from h2.exceptions import ProtocolError, StreamClosedError -from h2.settings import SettingCodes +from h2.events import ( + DataReceived, + PingReceived, + RemoteSettingsChanged, + RequestReceived, + ResponseReceived, + StreamEnded, + StreamReset, + TrailersReceived, + WindowUpdated, +) from dubbo.logger.logger_factory import loggerFactory +from dubbo.remoting.aio.constants import END_DATA_SENTINEL logger = loggerFactory.get_logger(__name__) -class Http2Protocol(asyncio.Protocol): +class HTTP2Protocol(asyncio.Protocol): def __init__(self, h2_config: H2Configuration): - h2_config.logger = logger - self.conn = H2Connection(config=h2_config) - self.transport = None - self.flow_control_futures = {} + # Create the H2 state machine + self.conn: H2Connection = H2Connection(config=h2_config) + + # the backing transport. + self.transport: Optional[asyncio.Transport] = None + + # The asyncio event loop. + self._loop = asyncio.get_running_loop() + + # A mapping of stream ID to stream object. + self.streams = {} + + # The `write_data_queue`, `flow_controlled_data`, and `send_data_loop_task` together form the flow control mechanism. + # Data flows between `write_queue` and `flow_controlled_data`. + # The `send_data_loop_task` blocks while reading data from the `write_queue` and attempts to send it. + # If a flow control limit is encountered, the unsent data is stored in `flow_controlled_data`, + # awaiting a WINDOW_UPDATE frame, at which point it is moved back from `flow_controlled_data` to `write_queue`. + self._write_data_queue = asyncio.Queue() + self._flow_controlled_data = {} + self._send_data_loop_task = None + + # Any streams that have been remotely reset. + self._reset_streams = set() def connection_made(self, transport: asyncio.Transport) -> None: + """ + Called when the connection is first established. We complete the following actions: + 1. Save the transport. + 2. Initialize the H2 connection. + 3. Create the send data loop task. + """ self.transport = transport self.conn.initiate_connection() + self.transport.write(self.conn.data_to_send()) + self._send_data_loop_task = self._loop.create_task(self._send_data_loop()) - def connection_lost(self, exc: Exception) -> None: - if exc: - logger.error(f"Connection lost: {exc}") - else: - logger.info("Connection closed cleanly.") - self.transport.close() + def connection_lost(self, exc) -> None: + """ + Called when the connection is lost. + """ + self._send_data_loop_task.cancel() - async def send_headers( + def send_head_frame( self, + stream_id: int, headers: List[Tuple[str, str]], - stream_id: Optional[int] = None, end_stream=False, - ) -> int: + head_event: Optional[asyncio.Event] = None, + ) -> asyncio.Event: """ - Send headers to the server or client. - Args: - headers: A list of header tuples. - stream_id: The stream ID to send the headers on. If None, a new stream will be created. - end_stream: Whether to close the stream after sending the headers. 
- Returns: - The stream ID the headers were sent on. - """ - if not stream_id: - # Get the next available stream ID. - stream_id = self.conn.get_next_available_stream_id() - self.conn.send_headers(stream_id, headers, end_stream=end_stream) - self.transport.write(self.conn.data_to_send()) - return stream_id + Send headers to the remote peer. + Because flow control is only for data frames, we can directly send the head frame rate. + Note: Only the first call sends a head frame, if called again, a trailer frame is sent. + """ + head_event = head_event or asyncio.Event() + + def _inner_send_header_frame(stream_id, headers, event): + self.conn.send_headers(stream_id, headers, end_stream) + self.transport.write(self.conn.data_to_send()) + event.set() - async def send_data(self, stream_id: int, data: bytes, end_stream=False) -> None: + # Send the header frame + self._loop.call_soon_threadsafe( + _inner_send_header_frame, stream_id, headers, head_event + ) + + return head_event + + def send_data_frame(self, stream_id: int, data) -> asyncio.Event: """ - Send data according to the flow control rules. + Send data to the remote peer. + The sending of data frames is subject to traffic control, + so we put them in a queue and send them according to traffic control rules Args: - stream_id: The stream ID to send the data on. - data: The data to send. - end_stream: Whether to close the stream after sending the data. - """ - while data: - # Check the flow control window. - while self.conn.local_flow_control_window(stream_id) < 1: - try: - # Wait for flow control window to open. - await self.wait_for_flow_control(stream_id) - except asyncio.CancelledError: - return - # Determine how much data to send. - chunk_size = min( - self.conn.local_flow_control_window(stream_id), - len(data), - self.conn.max_outbound_frame_size, - ) - try: - # Send the data. - self.conn.send_data( - stream_id, - data[:chunk_size], - end_stream=(chunk_size == len(data) and end_stream), - ) - except (StreamClosedError, ProtocolError): - logger.error( - f"Stream {stream_id} closed unexpectedly, aborting data send." - ) - break + stream_id: stream id + data: data + """ + event = asyncio.Event() - self.transport.write(self.conn.data_to_send()) - data = data[chunk_size:] + def _inner_send_data_frame(stream_id: int, data, event: asyncio.Event): + self._write_data_queue.put_nowait((stream_id, data, event)) - def data_received(self, data: bytes) -> None: - try: - # Parse the received data. - events = self.conn.receive_data(data) + self._loop.call_soon_threadsafe(_inner_send_data_frame, stream_id, data, event) + + return event + + async def _send_data_loop(self) -> None: + """ + Asynchronous blocking to get data from write_data_queue and try to send it, + this method implements the flow control mechanism + """ + while True: + stream_id, data, event = await self._write_data_queue.get() + + # If this stream got reset, just drop the data on the floor. + if stream_id in self._reset_streams: + event.set() + continue + + if data is END_DATA_SENTINEL: + self.conn.end_stream(stream_id) + self.transport.write(self.conn.data_to_send()) + event.set() + continue - if not events: + # We need to send data, but not to exceed the flow control window. 
+ window_size = self.conn.local_flow_control_window(stream_id) + chunk_size = min(window_size, len(data)) + data_to_send = data[:chunk_size] + data_to_buffer = data[chunk_size:] + + if data_to_send: + # Send the data frame + max_size = self.conn.max_outbound_frame_size + chunks = ( + data_to_send[x : x + max_size] + for x in range(0, len(data_to_send), max_size) + ) + for chunk in chunks: + self.conn.send_data(stream_id, chunk) self.transport.write(self.conn.data_to_send()) + + if data_to_buffer: + # We still have data to send, but it's blocked by traffic control, + # so we need to wait for the traffic window to open again. + self._flow_controlled_data[stream_id] = ( + stream_id, + data_to_buffer, + event, + ) else: - # Process the events. - for event in events: - if isinstance(event, ResponseReceived) or isinstance( - event, RequestReceived - ): - self.receive_headers(event.stream_id, event.headers) - elif isinstance(event, DataReceived): - self.receive_data(event.stream_id, event.data) - elif isinstance(event, TrailersReceived): - self.receive_trailers(event.stream_id, event.headers) - elif isinstance(event, StreamEnded): - self.receive_end(event.stream_id) - elif isinstance(event, WindowUpdated): - self.window_updated(event.stream_id, event.delta) - elif isinstance(event, RemoteSettingsChanged): - if SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings: - self.window_updated(None, 0) - - data = self.conn.data_to_send() - if data: - self.transport.write(data) - - except ProtocolError: - logger.exception("Parse HTTP2 frame error") - self.transport.write(self.conn.data_to_send()) - self.transport.close() + # We sent everything. + event.set() - async def wait_for_flow_control(self, stream_id) -> None: + def data_received(self, data: bytes) -> None: """ - Waits for a Future that fires when the flow control window is opened. + Process inbound data. """ - f = asyncio.Future() - self.flow_control_futures[stream_id] = f - await f + events = self.conn.receive_data(data) + for event in events: + self._process_event(event) + outbound_data = self.conn.data_to_send() + if outbound_data: + self.transport.write(outbound_data) - def window_updated(self, stream_id, delta) -> None: + def _process_event(self, event: h2.events.Event) -> Optional[bool]: """ - A window update frame was received. Unblock some number of flow control Futures. + Process an event. """ - if stream_id and stream_id in self.flow_control_futures: - future = self.flow_control_futures.pop(stream_id) - future.set_result(delta) + if isinstance(event, (RemoteSettingsChanged, PingReceived)): + # Events that are handled automatically by the H2 library. + # 1. RemoteSettingsChanged: h2 automatically acknowledges settings changes + # 2. PingReceived: A ping acknowledgment with the same opaque data is automatically emitted after receiving a ping. + pass + elif isinstance(event, WindowUpdated): + self.window_updated(event) + elif isinstance(event, StreamReset): + self.reset_stream(event) else: - # If it does not match, remove all flow control. - for f in self.flow_control_futures.values(): - f.set_result(delta) - self.flow_control_futures.clear() + # A False here means that the current event is not handled and needs to be handled by the subclass. + return False + + def window_updated(self, event: WindowUpdated) -> None: + """ + The flow control window got opened. + + """ + if event.stream_id: + # This is specific to a single stream. 
+ if event.stream_id in self._flow_controlled_data: + self._write_data_queue.put_nowait( + self._flow_controlled_data.pop(event.stream_id) + ) + else: + # This event is specific to the connection. + # Free up all the streams. + for data in self._flow_controlled_data.values(): + self._write_data_queue.put_nowait(data) + + self._flow_controlled_data = {} + + def reset_stream(self, event: StreamReset) -> None: + """ + The remote peer reset the stream. + """ + if event.stream_id in self._flow_controlled_data: + del self._flow_controlled_data + + self._reset_streams.add(event.stream_id) + + +class HTTP2ClientProtocol(HTTP2Protocol): + """ + An HTTP/2 client protocol. + """ + + def __init__(self): + h2_config = H2Configuration(client_side=True, header_encoding="utf-8") + super().__init__(h2_config) + + def register_stream(self, stream_id, stream): + self.streams[stream_id] = stream + + def _process_event(self, event): + if super()._process_event(event) is False: + if isinstance(event, ResponseReceived): + self.receive_headers(event) + elif isinstance(event, DataReceived): + self.receive_data(event) + elif isinstance(event, TrailersReceived): + self.receive_trailers(event) + elif isinstance(event, StreamEnded): + self.stream_ended(event) + + def receive_headers(self, event: ResponseReceived): + """ + The response headers have been received. + """ + self.streams[event.stream_id].receive_headers(event.headers) + + def receive_data(self, event: DataReceived): + """ + Data has been received. + """ + self.streams[event.stream_id].receive_data(event.data) + # Acknowledge the data, so the remote peer can send more. + self.conn.acknowledge_received_data( + event.flow_controlled_length, event.stream_id + ) + + def receive_trailers(self, event): + """ + Trailers have been received. + """ + self.streams[event.stream_id].receive_trailers(event.headers) + + def stream_ended(self, event): + """ + The stream has ended. + """ + self.streams[event.stream_id].receive_complete() + # Clean up the stream. + del self.streams[event.stream_id] + + def reset_stream(self, event: StreamReset) -> None: + super().reset_stream(event) + # TODO Pass the exception to the corresponding stream object + + +class HTTP2ServerProtocol(HTTP2Protocol): + + def __init__(self): + h2_config = H2Configuration(client_side=False, header_encoding="utf-8") + super().__init__(h2_config) + + def _process_event(self, event: h2.events.Event): + if super()._process_event(event) is False: + if isinstance(event, RequestReceived): + self.receive_headers(event) + elif isinstance(event, DataReceived): + self.receive_data(event) + elif isinstance(event, StreamEnded): + self.stream_ended(event) + + def receive_headers(self, event: RequestReceived): + """ + The request headers have been received. + """ + from dubbo.remoting.aio.aio_stream import AioServerStream + + s = AioServerStream(event.stream_id, self._loop, self) + self.streams[event.stream_id] = s + s.receive_headers(event.headers) + + def receive_data(self, event: DataReceived): + """ + Data has been received. + """ + self.streams[event.stream_id].receive_data(event.data) + + def stream_ended(self, event: StreamEnded): + """ + The stream has ended. 
+ """ + self.streams[event.stream_id].receive_complete() diff --git a/dubbo/serialization/serialization.py b/dubbo/serialization.py similarity index 96% rename from dubbo/serialization/serialization.py rename to dubbo/serialization.py index 937267b..2049eb1 100644 --- a/dubbo/serialization/serialization.py +++ b/dubbo/serialization.py @@ -17,9 +17,9 @@ from dubbo.common.constants import common_constants from dubbo.common.url import URL -from dubbo.logger import logger_factory +from dubbo.logger.logger_factory import loggerFactory -logger = logger_factory.get_logger(__name__) +logger = loggerFactory.get_logger(__name__) def serialize(method: str, url: URL, *args, **kwargs) -> bytes: diff --git a/requirements.txt b/requirements.txt index b782d68..97fc58d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -h2~=4.1.0 \ No newline at end of file +h2~=4.1.0 +uvloop~=0.19.0 \ No newline at end of file From dd83710167368442a4cb10911cd8ac986c21034f Mon Sep 17 00:00:00 2001 From: zaki Date: Sat, 6 Jul 2024 00:47:42 +0800 Subject: [PATCH 26/32] perf: Optimization of the network transmission part --- dubbo/remoting/aio/aio_stream.py | 208 -------------- dubbo/remoting/aio/h2_frame.py | 247 ++++++++++++++++ dubbo/remoting/aio/h2_protocol.py | 341 ++++++++++++++++++++++ dubbo/remoting/aio/h2_stream.py | 366 ++++++++++++++++++++++++ dubbo/remoting/aio/h2_stream_handler.py | 169 +++++++++++ dubbo/remoting/aio/http2_protocol.py | 327 --------------------- 6 files changed, 1123 insertions(+), 535 deletions(-) delete mode 100644 dubbo/remoting/aio/aio_stream.py create mode 100644 dubbo/remoting/aio/h2_frame.py create mode 100644 dubbo/remoting/aio/h2_protocol.py create mode 100644 dubbo/remoting/aio/h2_stream.py create mode 100644 dubbo/remoting/aio/h2_stream_handler.py delete mode 100644 dubbo/remoting/aio/http2_protocol.py diff --git a/dubbo/remoting/aio/aio_stream.py b/dubbo/remoting/aio/aio_stream.py deleted file mode 100644 index de708be..0000000 --- a/dubbo/remoting/aio/aio_stream.py +++ /dev/null @@ -1,208 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -from typing import List, Optional, Tuple - -from dubbo.logger.logger_factory import loggerFactory -from dubbo.protocol.triple.stream import ClientStream, ServerStream, Stream -from dubbo.remoting.aio.constants import END_DATA_SENTINEL - -logger = loggerFactory.get_logger(__name__) - -HEADER_FRAME = "HEADER_FRAME" -DATA_FRAME = "DATA_FRAME" -TRAILER_FRAME = "TRAILER_FRAME" - - -class AioStream(Stream): - """ - The Stream object for HTTP/2 - """ - - def __init__(self, stream_id: int, loop, protocol): - super().__init__(stream_id) - # The loop to run the asynchronous function. - self._loop = loop - # The protocol to send the frame. 
- self._protocol = protocol - - # The flag to indicate whether the header has been sent. - self._header_emitted = False - # This is an event that send a header frame. - # It is used to ensure that the header frame is sent before the data frame. - self._send_header_event: Optional[asyncio.Event] = None - - # The queue to store the all frames to send. It is used to ensure the order of the frames. - self._write_queue = asyncio.Queue() - # This is an event that send a data frame. - # It is used to ensure that the data frame is sent before the next data frame. - self._send_data_event: Optional[asyncio.Event] = None - - # The task to send the frames. - self._send_loop_task = self._loop.create_task(self._send_loop()) - - # The flag to indicate whether the sending is completed. - # However, it does not mean that all the data has been sent successfully, - # but is only used to prevent other data from being sent. - self._send_completed = False - - # The flag to indicate whether the receiving is completed. - self._receive_completed = False - - def send_headers(self, headers: List[Tuple[str, str]]) -> None: - """ - The first call sends the head frame, the second call sends the trailer frame. - Args: - headers: The headers to send. - """ - if self._send_completed: - raise RuntimeError("The stream has finished sending data") - - if self._header_emitted: - # If the header has been sent, it means that the trailer is being sent. - self._send_completed = True - else: - self._header_emitted = True - - def _inner_send_headers(headers, end_stream): - data_type = TRAILER_FRAME if end_stream else HEADER_FRAME - self._write_queue.put_nowait((data_type, headers)) - - self._loop.call_soon_threadsafe( - _inner_send_headers, headers, self._send_completed - ) - - def send_data(self, data: bytes) -> None: - """ - Send the data frame. - Args: - data: The data to send. - """ - if self._send_completed: - raise RuntimeError("The stream has finished sending data") - elif not self._header_emitted: - raise RuntimeError("The header has not been sent") - - def _inner_send_data(data): - self._write_queue.put_nowait((DATA_FRAME, data)) - - self._loop.call_soon_threadsafe(_inner_send_data, data) - - def send_end_stream(self) -> None: - """ - Send the end stream frame -> An empty data frame will be sent (end_stream=True) - """ - - def _inner_send_end_stream(): - self._write_queue.put_nowait((DATA_FRAME, END_DATA_SENTINEL)) - - self._loop.call_soon_threadsafe(_inner_send_end_stream) - - async def _send_loop(self): - """ - Asynchronous blocking to get data from write_queue and send it. - The purpose of using write_queue is to ensure that frames are sent in the following order: - 1. HEADER_FRAME - 2. DATA_FRAME (0 or more) - 3. TRAILER_FRAME (optional) - The format of the queue elements is: (type, data) -> (HEADER_FRAME, [("key", "value")]) or (DATA_FRAME, b"") - """ - while True: - data_type, data = await self._write_queue.get() - - if data_type == HEADER_FRAME: - # If the data is a header frame, send it directly. 
- self._send_header_event = self._protocol.send_head_frame( - self._stream_id, data - ) - continue - - # Waiting for the headers to be sent - assert self._send_header_event is not None - await self._send_header_event.wait() - - if self._send_data_event: - # Waiting for the previous message to be sent - await self._send_data_event.wait() - - if data_type == DATA_FRAME and data: - self._send_data_event = self._protocol.send_data_frame( - self._stream_id, data - ) - if data == END_DATA_SENTINEL: - # If it is an END_DATA_SENTINEL, it means that the data has been sent. - break - elif data_type == TRAILER_FRAME: - # If it is a TRAILER_FRAME, then it must also be a last frame, - # so it exits the loop when it finishes sending. - self._protocol.send_head_frame(self._stream_id, data, end_stream=True) - break - - -class AioClientStream(AioStream, ClientStream): - """ - The Stream object for the HTTP/2. (client side) - """ - - def __init__(self, loop, protocol, listener: ClientStream.Listener): - super().__init__(protocol.conn.get_next_available_stream_id(), loop, protocol) - self._protocol.register_stream(self._stream_id, self) - - # receive data - self._listener = listener - - def receive_headers(self, headers: List[Tuple[str, str]]) -> None: - """ - Receive the headers. - """ - # Running synchronized functions non-blocking - self._loop.run_in_executor(None, self._listener.on_headers, headers) - - def receive_data(self, data: bytes) -> None: - """ - Receive the data. - """ - self._loop.run_in_executor(None, self._listener.on_data, data) - - def receive_trailers(self, trailers: List[Tuple[str, str]]) -> None: - """ - Receive the trailers. - """ - self._loop.run_in_executor(None, self._listener.on_trailers, trailers) - - def receive_complete(self): - self._receive_completed = True - - -class AioServerStream(AioStream, ServerStream): - """ - The Stream object for the HTTP/2. (server side) - """ - - def __init__(self, stream_id, loop, protocol): - super().__init__(stream_id, loop, protocol) - - def receive_headers(self, headers: List[Tuple[str, str]]) -> None: - pass - - def receive_data(self, data: bytes) -> None: - pass - - def receive_trailers(self, trailers: List[Tuple[str, str]]) -> None: - pass - - def receive_complete(self): - self._receive_completed = True diff --git a/dubbo/remoting/aio/h2_frame.py b/dubbo/remoting/aio/h2_frame.py new file mode 100644 index 0000000..af3f0d5 --- /dev/null +++ b/dubbo/remoting/aio/h2_frame.py @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import enum +import sys +import time +from typing import Any, Dict, Optional + +from h2.events import ( + DataReceived, + Event, + RequestReceived, + ResponseReceived, + StreamReset, + TrailersReceived, + WindowUpdated, +) + + +class H2FrameType(enum.Enum): + """ + Enum class representing HTTP/2 frame types. + """ + + # Data frame, carries HTTP message bodies. + DATA = 0x0 + # Headers frame, carries HTTP headers. + HEADERS = 0x1 + # Priority frame, specifies the priority of a stream. + PRIORITY = 0x2 + # Reset Stream frame, cancels a stream. + RST_STREAM = 0x3 + # Settings frame, exchanges configuration parameters. + SETTINGS = 0x4 + # Push Promise frame, used by the server to push resources. + PUSH_PROMISE = 0x5 + # Ping frame, measures round-trip time and checks connectivity. + PING = 0x6 + # Goaway frame, signals that the connection will be closed. + GOAWAY = 0x7 + # Window Update frame, manages flow control window size. + WINDOW_UPDATE = 0x8 + # Continuation frame, transmits large header blocks. + CONTINUATION = 0x9 + + +class H2Frame: + """ + HTTP/2 frame class. It is used to represent an HTTP/2 frame. + Args: + stream_id: The stream identifier. + frame_type: The frame type. + data: The data to send. such as: HEADERS: List[Tuple[str, str]], DATA: bytes, END_STREAM: None or bytes. + end_stream: Whether the stream is ended. + attributes: The attributes of the frame. + """ + + def __init__( + self, + stream_id: int, + frame_type: H2FrameType, + data: Any = None, + end_stream: bool = False, + attributes: Optional[Dict[str, Any]] = None, + ): + self._stream_id = stream_id + self._frame_type = frame_type + self._data = data + self._end_stream = end_stream + self._attributes = attributes or {} + + # The timestamp of the generated frame. -> comparison for Priority Queue + self._timestamp = int(round(time.time() * 1000)) + + @property + def stream_id(self) -> int: + return self._stream_id + + @property + def frame_type(self) -> H2FrameType: + return self._frame_type + + @property + def data(self) -> Any: + return self._data + + @data.setter + def data(self, data: Any) -> None: + self._data = data + + @property + def end_stream(self) -> bool: + return self._end_stream + + @property + def attributes(self) -> Dict[str, Any]: + return self._attributes + + def __lt__(self, other: "H2Frame") -> bool: + return self._timestamp < other._timestamp + + def __str__(self): + return ( + f"H2Frame(stream_id={self.stream_id}, " + f"frame_type={self.frame_type}, " + f"data={self.data}, " + f"end_stream={self.end_stream}, " + f"attributes={self.attributes})" + ) + + +DATA_COMPLETED_FRAME: H2Frame = H2Frame(0, H2FrameType.DATA, b"") +# Make use of the infinity timestamp to ensure that the DATA_COMPLETED_FRAME is always at the end of the data queue. +DATA_COMPLETED_FRAME._timestamp = sys.maxsize + + +class H2FrameUtils: + """ + Utility class for creating HTTP/2 frames. + """ + + @staticmethod + def create_headers_frame( + stream_id: int, + headers: list[tuple[str, str]], + end_stream: bool = False, + attributes: Optional[Dict[str, str]] = None, + ) -> H2Frame: + """ + Create a headers frame. + Args: + stream_id: The stream identifier. + headers: The headers to send. + end_stream: Whether the stream is ended. + attributes: The attributes of the frame. + Returns: + The headers frame. 
+ """ + return H2Frame(stream_id, H2FrameType.HEADERS, headers, end_stream, attributes) + + @staticmethod + def create_data_frame( + stream_id: int, + data: bytes, + end_stream: bool = False, + attributes: Optional[Dict[str, str]] = None, + ) -> H2Frame: + """ + Create a data frame. + Args: + stream_id: The stream identifier. + data: The data to send. + end_stream: Whether the stream is ended. + attributes: The attributes of the frame. + Returns: + The data frame. + """ + return H2Frame(stream_id, H2FrameType.DATA, data, end_stream, attributes) + + @staticmethod + def create_reset_stream_frame( + stream_id: int, + error_code: int, + attributes: Optional[Dict[str, str]] = None, + ) -> H2Frame: + """ + Create a reset stream frame. + Args: + stream_id: The stream identifier. + error_code: The error code. + attributes: The attributes of the frame. + Returns: + The reset stream frame. + """ + return H2Frame( + stream_id, + H2FrameType.RST_STREAM, + error_code, + end_stream=True, + attributes=attributes, + ) + + @staticmethod + def create_window_update_frame( + stream_id: int, + increment: int, + attributes: Optional[Dict[str, str]] = None, + ) -> H2Frame: + """ + Create a window update frame. + Args: + stream_id: The stream identifier. + increment: The increment. + attributes: The attributes of the frame. + Returns: + The window update frame. + """ + return H2Frame( + stream_id, H2FrameType.WINDOW_UPDATE, increment, attributes=attributes + ) + + @staticmethod + def create_frame_by_event(event: Event) -> Optional[H2Frame]: + """ + Create a frame by the h2.events.Event. + Args: + event: The h2.events.Event. + Returns: + The H2Frame. None if the event is not supported or not implemented. + """ + if isinstance(event, (RequestReceived, ResponseReceived)): + # The headers frame. + return H2FrameUtils.create_headers_frame( + event.stream_id, event.headers, event.stream_ended is not None + ) + elif isinstance(event, TrailersReceived): + return H2FrameUtils.create_headers_frame( + event.stream_id, event.headers, end_stream=True + ) + elif isinstance(event, DataReceived): + # The data frame. + return H2FrameUtils.create_data_frame( + event.stream_id, + event.data, + end_stream=event.stream_ended is not None, + attributes={"flow_controlled_length": event.flow_controlled_length}, + ) + elif isinstance(event, StreamReset): + # The reset stream frame. + return H2FrameUtils.create_reset_stream_frame( + event.stream_id, event.error_code + ) + elif isinstance(event, WindowUpdated): + # The window update frame. + return H2FrameUtils.create_window_update_frame(event.stream_id, event.delta) diff --git a/dubbo/remoting/aio/h2_protocol.py b/dubbo/remoting/aio/h2_protocol.py new file mode 100644 index 0000000..1707f7c --- /dev/null +++ b/dubbo/remoting/aio/h2_protocol.py @@ -0,0 +1,341 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from typing import Dict, Optional, Tuple + +from h2.config import H2Configuration +from h2.connection import H2Connection + +from dubbo.logger.logger_factory import loggerFactory +from dubbo.remoting.aio.h2_frame import H2Frame, H2FrameType, H2FrameUtils +from dubbo.remoting.aio.h2_stream_handler import StreamHandler + +logger = loggerFactory.get_logger(__name__) + + +class DataFlowControl: + """ + DataFlowControl is responsible for managing HTTP/2 data flow, handling flow control, + and ensuring data frames are sent according to the HTTP/2 flow control rules. + + Note: + The class is not thread-safe and does not need to be designed as thread-safe + because there can be only one DataFlowControl corresponding to an HTTP2 connection. + + Args: + protocol (H2Protocol): The protocol instance used to send frames. + loop (asyncio.AbstractEventLoop): The asyncio event loop. + """ + + def __init__(self, protocol, loop: asyncio.AbstractEventLoop): + # The protocol instance used to send frames. + self.protocol: H2Protocol = protocol + + # The asyncio event loop. + self.loop = loop + + # Queue for storing data to be sent out + self._outbound_data_queue: asyncio.Queue[Tuple[H2Frame, asyncio.Event]] = ( + asyncio.Queue() + ) + + # Dictionary for storing data that could not be sent due to flow control limits + self._flow_control_data: Dict[int, Tuple[H2Frame, asyncio.Event]] = {} + + # Set of streams that need to be reset + self._reset_streams = set() + + # Task for the data sender loop. + self._data_sender_loop_task = None + + def start(self) -> None: + """ + Start the data sender loop. + This creates and starts an asyncio task that runs the _data_sender_loop coroutine. + """ + # Start the data sender loop + self._data_sender_loop_task = self.loop.create_task(self._data_sender_loop()) + + def cancel(self) -> None: + """ + Cancel the data sender loop. + This cancels the asyncio task running the _data_sender_loop coroutine. + """ + if self._data_sender_loop_task: + self._data_sender_loop_task.cancel() + + def put(self, frame: H2Frame, event: asyncio.Event) -> None: + """ + Put a data frame into the outbound data queue. + + Args: + frame (H2Frame): The data frame to send. + event (asyncio.Event): The event to notify when the data frame is sent. + """ + self._outbound_data_queue.put_nowait((frame, event)) + + def release(self, frame: H2Frame) -> None: + """ + Release the flow control for the stream. + + Args: + frame (H2Frame): The data frame to release the flow control. + It must be a WINDOW_UPDATE frame. + """ + if frame.frame_type != H2FrameType.WINDOW_UPDATE: + raise TypeError("The frame is not a window update frame") + + stream_id = frame.stream_id + if stream_id: + # This is specific to a single stream. + if stream_id in self._flow_control_data: + data_frame_event = self._flow_control_data.pop(stream_id) + self._outbound_data_queue.put_nowait(data_frame_event) + else: + # This is for the entire connection. + for data_frame_event in self._flow_control_data.values(): + self._outbound_data_queue.put_nowait(data_frame_event) + # Clear the pending data + self._flow_control_data = {} + + def reset(self, frame: H2Frame) -> None: + """ + Reset the stream. + + Args: + frame (H2Frame): The reset frame. It must be an RST_STREAM frame. 
+ """ + if frame.frame_type != H2FrameType.RST_STREAM: + raise TypeError("The frame is not a reset stream frame") + + if frame.stream_id in self._flow_control_data: + del self._flow_control_data[frame.stream_id] + + self._reset_streams.add(frame.stream_id) + + async def _data_sender_loop(self) -> None: + """ + Coroutine that continuously sends data frames from the outbound data queue + while respecting flow control limits. + """ + while True: + # Get the frame from the outbound data queue -> it's a blocking operation, but asynchronous. + data_frame: H2Frame + event: asyncio.Event + data_frame, event = await self._outbound_data_queue.get() + + # If the frame is not a data frame, ignore it. + if data_frame.frame_type != H2FrameType.DATA: + logger.warning(f"Invalid frame type: {data_frame.frame_type}, ignored") + event.set() + continue + + # Get the stream ID and data from the frame. + stream_id = data_frame.stream_id + data = data_frame.data + end_stream = data_frame.end_stream + + # The stream has been reset, so we don't send any data. + if stream_id in self._reset_streams: + event.set() + continue + + # We need to send data, but not to exceed the flow control window. + window_size = self.protocol.conn.local_flow_control_window(stream_id) + chunk_size = min(window_size, len(data)) + data_to_send = data[:chunk_size] + data_to_buffer = data[chunk_size:] + + if data_to_send: + # Send the data frame + max_size = self.protocol.conn.max_outbound_frame_size + + # Split the data into chunks and send them out + for x in range(0, len(data), max_size): + chunk = data[x : x + max_size] + end_stream_flag = ( + end_stream + and data_to_buffer == b"" + and x + max_size >= len(data) + ) + self.protocol.conn.send_data( + stream_id, chunk, end_stream=end_stream_flag + ) + + self.protocol.transport.write(self.protocol.conn.data_to_send()) + elif end_stream: + # If there is no data to send, but the stream is ended, send an empty data frame. + self.protocol.conn.send_data(stream_id, b"", end_stream=True) + self.protocol.transport.write(self.protocol.conn.data_to_send()) + + if data_to_buffer: + # Store the data that could not be sent due to flow control limits + data_frame.data = data_to_buffer + self._flow_control_data[stream_id] = (data_frame, event) + else: + # We sent everything. + event.set() + + +class H2Protocol(asyncio.Protocol): + """ + Implements an HTTP/2 protocol using asyncio's Protocol class. + + This class sets up and manages an HTTP/2 connection using the h2 library. + It handles connection state, stream mapping, and data flow control. + + Args: + h2_config (H2Configuration): The configuration for the H2 connection. + stream_handler (StreamHandler): The handler for managing streams. + + """ + + def __init__(self, h2_config: H2Configuration, stream_handler: StreamHandler): + # Create the H2 state machine + self.conn: H2Connection = H2Connection(config=h2_config) + + # the backing transport. + self.transport: Optional[asyncio.Transport] = None + + # The asyncio event loop. + self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + + # A mapping of stream ID to stream object. + self._stream_handler: StreamHandler = stream_handler + + self._data_follow_control: Optional[DataFlowControl] = None + + def connection_made(self, transport: asyncio.Transport) -> None: + """ + Called when the connection is first established. We complete the following actions: + 1. Save the transport. + 2. Initialize the H2 connection. + 3. Initialize the StreamHandler. + 3. 
+        4. Create the data flow control and start its task.
+        """
+        self.transport = transport
+        self.conn.initiate_connection()
+        self.transport.write(self.conn.data_to_send())
+
+        # Initialize the StreamHandler
+        self._stream_handler.init(self.loop, self)
+
+        # Create the data flow control object and start its task.
+        self._data_follow_control = DataFlowControl(self, self.loop)
+        self._data_follow_control.start()
+
+    def connection_lost(self, exc) -> None:
+        """
+        Called when the connection is lost.
+        Args:
+            exc: The exception that caused the connection to be lost.
+        """
+        self._stream_handler.destroy()
+        self._data_follow_control.cancel()
+
+    def send_headers_frame(self, headers_frame: H2Frame) -> asyncio.Event:
+        """
+        Send headers to the remote peer. (thread-safe)
+        Note:
+            Only the first call sends the HEADERS frame; subsequent calls send trailer frames.
+        Args:
+            headers_frame(H2Frame): The headers frame to send.
+        Returns:
+            asyncio.Event: The event that is set when the headers frame is sent.
+        """
+        headers_event = asyncio.Event()
+
+        def _inner_send_headers_frame(headers_frame: H2Frame, event: asyncio.Event):
+            self.conn.send_headers(
+                headers_frame.stream_id, headers_frame.data, headers_frame.end_stream
+            )
+            self.transport.write(self.conn.data_to_send())
+            # Set the event to indicate that the headers frame has been sent.
+            event.set()
+
+        # Send the headers frame
+        self.loop.call_soon_threadsafe(
+            _inner_send_headers_frame, headers_frame, headers_event
+        )
+
+        return headers_event
+
+    def send_data_frame(self, data_frame: H2Frame) -> asyncio.Event:
+        """
+        Send data to the remote peer. (thread-safe)
+        The sending of data frames is subject to flow control.
+        Args:
+            data_frame(H2Frame): The data frame to send.
+        Returns:
+            asyncio.Event: The event that is set when the data frame is sent.
+        """
+        data_event = asyncio.Event()
+
+        def _inner_send_data_frame(_data_frame: H2Frame, event: asyncio.Event):
+            self._data_follow_control.put(_data_frame, event)
+
+        self.loop.call_soon_threadsafe(_inner_send_data_frame, data_frame, data_event)
+
+        return data_event
+
+    def send_reset_frame(self, reset_frame: H2Frame) -> None:
+        """
+        Send the reset frame to the remote peer. (thread-safe)
+        Args:
+            reset_frame(H2Frame): The reset frame to send.
+        """
+
+        def _inner_send_reset_frame(_reset_frame: H2Frame):
+            self.conn.reset_stream(_reset_frame.stream_id, _reset_frame.data)
+            self.transport.write(self.conn.data_to_send())
+            # remove the stream from the stream handler
+            self._stream_handler.remove(_reset_frame.stream_id)
+
+        self.loop.call_soon_threadsafe(_inner_send_reset_frame, reset_frame)
+
+    def data_received(self, data: bytes) -> None:
+        """
+        Process inbound data.
+        """
+        events = self.conn.receive_data(data)
+        # Process the events
+        for event in events:
+            frame = H2FrameUtils.create_frame_by_event(event)
+            if not frame:
+                # If frame is None, there are two possible cases:
+                # 1. Events that are handled automatically by the h2 library -> we just need to flush any pending data.
+                #    e.g. RemoteSettingsChanged, PingReceived
+                # 2. Events that are not implemented or do not require attention -> we ignore them for now.
+ pass + else: + # The frames we focus on include: HEADERS, DATA, WINDOW_UPDATE, RST_STREAM + if frame.frame_type == H2FrameType.WINDOW_UPDATE: + # Update the flow control window + self._data_follow_control.release(frame) + else: + # Handle the frame + self._stream_handler.handle_frame(frame) + + # Acknowledge the received data + if frame.frame_type == H2FrameType.DATA: + self.conn.acknowledge_received_data( + frame.attributes["flow_controlled_length"], frame.stream_id + ) + + # If there is data to send, send it. + outbound_data = self.conn.data_to_send() + if outbound_data: + self.transport.write(outbound_data) diff --git a/dubbo/remoting/aio/h2_stream.py b/dubbo/remoting/aio/h2_stream.py new file mode 100644 index 0000000..5880fee --- /dev/null +++ b/dubbo/remoting/aio/h2_stream.py @@ -0,0 +1,366 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from typing import List, Optional, Tuple + +from dubbo.logger.logger_factory import loggerFactory +from dubbo.remoting.aio.h2_frame import ( + DATA_COMPLETED_FRAME, + H2Frame, + H2FrameType, + H2FrameUtils, +) + +logger = loggerFactory.get_logger(__name__) + + +class StreamFrameControl: + """ + This class is responsible for controlling the order and sending of frames in an HTTP/2 stream. + It ensures that frames are sent in the correct sequence, specifically HEADERS, DATA (0 or more), + and optional TRAILERS. + + Note: + 1. + This class is not thread-safe and does not need to be designed as thread-safe because it + is used only within a single Stream object. However, asynchronous call safety must be ensured. + 2. Special frames like RESET can be sent without following this sequence. + 3. Each Stream object corresponds to a StreamFrameControl object. + + + Args: + protocol(H2Protocol): The protocol instance used to send frames. + loop(asyncio.AbstractEventLoop): The asyncio event loop. + """ + + def __init__(self, protocol, loop: asyncio.AbstractEventLoop): + # Import here to avoid looping imports + from dubbo.remoting.aio.h2_protocol import H2Protocol + + # The protocol instance used to send frames. + self._protocol: H2Protocol = protocol + + # The asyncio event loop. 
+ self._loop = loop + + # The queue for storing frames + # HEADERS: 0, DATA: 1, TRAILERS: 2 + self._frame_queue = asyncio.PriorityQueue() + + # The event for the start of the stream -> Ensure that HEADERS frame have been placed in the queue + self._start_event: asyncio.Event = asyncio.Event() + + # The event for the headers frame -> Ensure that HEADERS frame have been sent + self._headers_event: Optional[asyncio.Event] = None + + # The event for the data frame -> Ensure that previous DATA frame have been sent + self._data_event: Optional[asyncio.Event] = None + + # The flag to indicate whether the data is completed -> Ensure that all data frames have been placed in the queue + self._data_completed = False + + # TRAILERS frame storage + self._trailers_frame: Optional[H2Frame] = None + + self._frame_sender_loop_task = None + + def start(self): + """ + Start the frame sender loop. + This creates and starts an asyncio task that runs the _frame_sender_loop coroutine. + """ + self._frame_sender_loop_task = self._loop.create_task(self._frame_sender_loop()) + + def cancel(self): + """ + Cancel the frame sender loop. + This cancels the asyncio task running the _frame_sender_loop coroutine. + """ + if self._frame_sender_loop_task: + self._frame_sender_loop_task.cancel() + + def put_headers(self, headers_frame: H2Frame): + """ + Put a HEADERS frame into the frame queue. + + Args: + headers_frame (H2Frame): The HEADERS frame to be added. + + Raises: + TypeError: If the frame is not a HEADERS frame. + """ + if headers_frame.frame_type != H2FrameType.HEADERS: + raise TypeError("The frame is not a HEADERS frame") + + # If the start event is not set, set it. + if not self._start_event.is_set(): + # HEADERS + self._frame_queue.put_nowait((0, headers_frame)) + self._start_event.set() + else: + # TRAILERS + self.put_trailers_later(headers_frame) + + def put_data(self, data_frame: H2Frame): + """ + Put a DATA frame into the frame queue. + + Args: + data_frame (H2Frame): The DATA frame to be added. + + Raises: + TypeError: If the frame is not a DATA frame. + RuntimeError: If the data is completed, no more data can be sent. + """ + if data_frame.frame_type != H2FrameType.DATA: + raise TypeError("The frame is not a DATA frame") + elif self._data_completed: + raise RuntimeError("The data is completed, no more data can be sent.") + + if data_frame == DATA_COMPLETED_FRAME: + # The data is completed + self._data_completed = True + if self._trailers_frame: + # Make sure TRAILERS are sent after DATA + self.put_trailers_now(self._trailers_frame) + else: + self._data_completed = data_frame.end_stream + self._frame_queue.put_nowait((1, data_frame)) + + def put_trailers_now(self, trailers_frame: H2Frame): + """ + Immediately put a TRAILERS frame into the frame queue. + + Note: You should call this method when you don't need to send DATA. + + Args: + trailers_frame (H2Frame): The TRAILERS frame to be added. + + Raises: + TypeError: If the frame is not a HEADERS frame. + """ + if trailers_frame.frame_type != H2FrameType.HEADERS: + raise TypeError("The frame is not a HEADERS frame") + + self._frame_queue.put_nowait((2, trailers_frame)) + + def put_trailers_later(self, trailers_frame: H2Frame): + """ + Store the TRAILERS frame to be sent after all DATA frames. + + Note: When you need to send DATA, you should call this method. + + Args: + trailers_frame (H2Frame): The TRAILERS frame to be stored. + + Raises: + TypeError: If the frame is not a HEADERS frame. 
+ """ + self._trailers_frame = trailers_frame + + async def _frame_sender_loop(self): + """ + The main loop for sending frames. This loop continuously fetches frames from the queue and sends them in the + correct order. + + It ensures that HEADERS frames are sent before any DATA frames, and waits for the completion events of HEADERS + and DATA frames before sending subsequent frames. + + If a frame has the end_stream flag set, the loop breaks, indicating the end of the stream. + """ + while True: + # Wait for the start event + await self._start_event.wait() + + # Get the frame from the outbound data queue -> it's a blocking operation, but asynchronous. + priority, frame = await self._frame_queue.get() + + # If the frame is HEADERS, send the header frame directly. + if frame.frame_type == H2FrameType.HEADERS and not self._headers_event: + self._headers_event = self._protocol.send_headers_frame(frame) + else: + # Wait for HEADERS to be sent. + await self._headers_event.wait() + + # Waiting for the previous DATA to be sent. + if self._data_event: + await self._data_event.wait() + + if frame.frame_type == H2FrameType.DATA: + # Send the data frame and store the event. + self._data_event = self._protocol.send_data_frame(frame) + elif frame.frame_type == H2FrameType.HEADERS: + # Send the trailers frame. + self._protocol.send_headers_frame(frame) + + if frame.end_stream: + # The stream is completed. we can break the loop. + break + + +class Stream: + """ + Stream is a bidirectional channel that manipulates the data flow between peers. + + This class manages the sending and receiving of HTTP/2 frames for a single stream. + It ensures frames are sent in the correct order and handles flow control for the stream. + + Args: + stream_id (int): The stream identifier. + protocol (H2Protocol): The protocol instance used to send frames. + loop (asyncio.AbstractEventLoop): The asyncio event loop. + + """ + + def __init__(self, stream_id: int, protocol, loop: asyncio.AbstractEventLoop): + # import here to avoid circular import + from dubbo.remoting.aio.h2_protocol import H2Protocol + + # The protocol. + self._protocol: H2Protocol = protocol + + # The stream ID. + self._stream_id: int = stream_id + + # The asyncio event loop. + self._loop = loop + + # The stream frame control. + self._stream_frame_control = StreamFrameControl(protocol, loop) + self._stream_frame_control.start() + + # The flag to indicate whether the sending is completed. + self._send_completed = False + + # The flag to indicate whether the receiving is completed. + self._receive_completed = False + + def send_headers( + self, headers: List[Tuple[str, str]], end_stream: bool = False + ) -> None: + """ + Send the headers frame. The first call sends the head frame, the second call sends the trailer frame. + + Args: + headers (List[Tuple[str, str]]): The headers to send. + end_stream (bool): Whether to end the stream after sending this frame. + """ + if self._send_completed: + return + else: + self._send_completed = end_stream + + def _inner_send_headers(_headers: List[Tuple[str, str]], _end_stream: bool): + headers_frame = H2FrameUtils.create_headers_frame( + self._stream_id, _headers, _end_stream + ) + self._stream_frame_control.put_headers(headers_frame) + if end_stream: + # The data is completed. + self._stream_frame_control.put_data(DATA_COMPLETED_FRAME) + + self._loop.call_soon_threadsafe(_inner_send_headers, headers, end_stream) + + def close(self) -> None: + """ + Close the stream by cancelling the frame sender loop. 
+ """ + self._stream_frame_control.cancel() + + def send_data(self, data: bytes, end_stream: bool = False) -> None: + """ + Send a data frame. + + Args: + data (bytes): The data to send. + end_stream (bool): Whether to end the stream after sending this frame. + """ + if self._send_completed: + logger.info("Send completed.") + return + else: + self._send_completed = end_stream + + def _inner_send_data(_data: bytes, _end_stream: bool): + data_frame = H2FrameUtils.create_data_frame( + self._stream_id, _data, _end_stream + ) + self._stream_frame_control.put_data(data_frame) + + self._loop.call_soon_threadsafe(_inner_send_data, data, end_stream) + + def send_reset(self, error_code: int) -> None: + """ + Send a reset frame to terminate the stream. + + Note: This is a special frame and does not need to follow the sequence of frames. + + Args: + error_code (int): The error code indicating the reason for the reset. + """ + self._send_completed = True + + def _inner_send_reset(_error_code: int): + reset_frame = H2FrameUtils.create_reset_stream_frame( + self._stream_id, _error_code + ) + self._protocol.send_reset_frame(reset_frame) + self._stream_frame_control.cancel() + + self._loop.call_soon_threadsafe(_inner_send_reset, error_code) + + def receive_headers(self, headers: List[Tuple[str, str]]) -> None: + """ + Called when a headers frame is received. + + Args: + headers (List[Tuple[str, str]]): The headers received. + """ + raise NotImplementedError("receive_headers() is not implemented") + + def receive_data(self, data: bytes) -> None: + """ + Called when a data frame is received. + + Args: + data (bytes): The data received. + """ + raise NotImplementedError("receive_data() is not implemented") + + def receive_complete(self) -> None: + """ + Called when the stream is completed. + """ + self._receive_completed = True + + def cancel_by_remote(self, err_code: int) -> None: + """ + Called when the stream is cancelled by the remote peer. + + Args: + err_code (int): The error code indicating the reason for cancellation. + """ + raise NotImplementedError("cancel_by_remote() is not implemented") + + +class ClientStream(Stream): + # TODO implement the ClientStream + pass + + +class ServerStream(Stream): + # TODO implement the ServerStream + pass diff --git a/dubbo/remoting/aio/h2_stream_handler.py b/dubbo/remoting/aio/h2_stream_handler.py new file mode 100644 index 0000000..257bcfc --- /dev/null +++ b/dubbo/remoting/aio/h2_stream_handler.py @@ -0,0 +1,169 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import asyncio +from concurrent.futures import Future as ThreadingFuture +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, Optional + +from dubbo.logger.logger_factory import loggerFactory +from dubbo.remoting.aio.h2_frame import H2Frame, H2FrameType +from dubbo.remoting.aio.h2_stream import ClientStream, ServerStream, Stream + +logger = loggerFactory.get_logger(__name__) + + +class StreamHandler: + """ + Stream handler class. It is used to handle the stream in the connection. + Args: + executor(ThreadPoolExecutor): The executor to handle the frame. + """ + + def __init__( + self, + executor: Optional[ThreadPoolExecutor] = None, + ): + # import here to avoid circular import + from dubbo.remoting.aio.h2_protocol import H2Protocol + + self._protocol: Optional[H2Protocol] = None + + # The event loop to run the asynchronous function. + self._loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_event_loop() + + # The streams managed by the handler + self._streams: Dict[int, Stream] = {} + + # The executor to handle the frame, If None, the default executor will be used. + self._executor = executor + + def init(self, loop: asyncio.AbstractEventLoop, protocol) -> None: + """ + Initialize the handler with the protocol. + Args: + loop(asyncio.AbstractEventLoop): The event loop. + protocol(H2Protocol): The protocol. + """ + self._loop = loop + self._protocol = protocol + + def handle_frame(self, frame: H2Frame) -> None: + """ + Handle the frame received from the connection. + Args: + frame: The frame to handle. + """ + # Handle the frame in the executor + self._loop.run_in_executor(self._executor, self._handle_in_executor, frame) + + def _handle_in_executor(self, frame: H2Frame) -> None: + """ + Actually handle the frame in the executor. + Args: + frame: The frame to handle. + """ + stream = self._streams.get(frame.stream_id) + + if not stream: + logger.warning(f"Unknown stream: id={frame.stream_id}") + return + + frame_type = frame.frame_type + if frame_type == H2FrameType.HEADERS: + stream.receive_headers(frame.data) + elif frame_type == H2FrameType.DATA: + stream.receive_data(frame.data) + elif frame_type == H2FrameType.RST_STREAM: + stream.cancel_by_remote(frame.data) + else: + logger.debug(f"Unhandled frame: {frame_type}") + + if frame.end_stream: + stream.receive_complete() + + def create(self) -> Stream: + """ + Create a new stream. -> Client + Returns: + Stream: The stream object. + """ + raise NotImplementedError("create() is not implemented") + + def register(self, stream_id: int) -> None: + """ + Register the stream to the handler -> Server + Args: + stream_id: The stream ID. + """ + raise NotImplementedError("register() is not implemented") + + def remove(self, stream_id: int) -> None: + """ + Remove the stream from the handler -> Server + Args: + stream_id: The stream ID. + """ + del self._streams[stream_id] + + def destroy(self) -> None: + """ + Destroy the handler. + """ + for stream in self._streams.values(): + stream.close() + self._streams.clear() + + +class ClientStreamHandler(StreamHandler): + + def create(self) -> Stream: + """ + Create a new stream. 
-> Client + """ + # Create a new client stream + future = ThreadingFuture() + + def _inner_create(future: ThreadingFuture): + new_stream_id = self._protocol.conn.get_next_available_stream_id() + new_stream = ClientStream(new_stream_id, self._protocol, self._loop) + self._streams[new_stream_id] = new_stream + future.set_result(new_stream) + + self._loop.call_soon_threadsafe(_inner_create, future) + return future.result() + + # TODO implement ClientStreamHandler... + + +class ServerStreamHandler(StreamHandler): + + def register(self, stream_id: int) -> None: + """ + Register the stream to the handler -> Server + """ + new_stream = ServerStream(stream_id, self._protocol, self._loop) + self._streams[stream_id] = new_stream + + def handle_frame(self, frame: H2Frame) -> None: + # Register the stream if it is a HEADERS frame and the stream is not registered. + if ( + frame.frame_type == H2FrameType.HEADERS + and frame.stream_id not in self._streams + ): + self.register(frame.stream_id) + super().handle_frame(frame) + + # TODO implement ServerStreamHandler... diff --git a/dubbo/remoting/aio/http2_protocol.py b/dubbo/remoting/aio/http2_protocol.py deleted file mode 100644 index cd5e064..0000000 --- a/dubbo/remoting/aio/http2_protocol.py +++ /dev/null @@ -1,327 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -from typing import List, Optional, Tuple - -import h2.events -from h2.config import H2Configuration -from h2.connection import H2Connection -from h2.events import ( - DataReceived, - PingReceived, - RemoteSettingsChanged, - RequestReceived, - ResponseReceived, - StreamEnded, - StreamReset, - TrailersReceived, - WindowUpdated, -) - -from dubbo.logger.logger_factory import loggerFactory -from dubbo.remoting.aio.constants import END_DATA_SENTINEL - -logger = loggerFactory.get_logger(__name__) - - -class HTTP2Protocol(asyncio.Protocol): - - def __init__(self, h2_config: H2Configuration): - # Create the H2 state machine - self.conn: H2Connection = H2Connection(config=h2_config) - - # the backing transport. - self.transport: Optional[asyncio.Transport] = None - - # The asyncio event loop. - self._loop = asyncio.get_running_loop() - - # A mapping of stream ID to stream object. - self.streams = {} - - # The `write_data_queue`, `flow_controlled_data`, and `send_data_loop_task` together form the flow control mechanism. - # Data flows between `write_queue` and `flow_controlled_data`. - # The `send_data_loop_task` blocks while reading data from the `write_queue` and attempts to send it. - # If a flow control limit is encountered, the unsent data is stored in `flow_controlled_data`, - # awaiting a WINDOW_UPDATE frame, at which point it is moved back from `flow_controlled_data` to `write_queue`. 
- self._write_data_queue = asyncio.Queue() - self._flow_controlled_data = {} - self._send_data_loop_task = None - - # Any streams that have been remotely reset. - self._reset_streams = set() - - def connection_made(self, transport: asyncio.Transport) -> None: - """ - Called when the connection is first established. We complete the following actions: - 1. Save the transport. - 2. Initialize the H2 connection. - 3. Create the send data loop task. - """ - self.transport = transport - self.conn.initiate_connection() - self.transport.write(self.conn.data_to_send()) - self._send_data_loop_task = self._loop.create_task(self._send_data_loop()) - - def connection_lost(self, exc) -> None: - """ - Called when the connection is lost. - """ - self._send_data_loop_task.cancel() - - def send_head_frame( - self, - stream_id: int, - headers: List[Tuple[str, str]], - end_stream=False, - head_event: Optional[asyncio.Event] = None, - ) -> asyncio.Event: - """ - Send headers to the remote peer. - Because flow control is only for data frames, we can directly send the head frame rate. - Note: Only the first call sends a head frame, if called again, a trailer frame is sent. - """ - head_event = head_event or asyncio.Event() - - def _inner_send_header_frame(stream_id, headers, event): - self.conn.send_headers(stream_id, headers, end_stream) - self.transport.write(self.conn.data_to_send()) - event.set() - - # Send the header frame - self._loop.call_soon_threadsafe( - _inner_send_header_frame, stream_id, headers, head_event - ) - - return head_event - - def send_data_frame(self, stream_id: int, data) -> asyncio.Event: - """ - Send data to the remote peer. - The sending of data frames is subject to traffic control, - so we put them in a queue and send them according to traffic control rules - Args: - stream_id: stream id - data: data - """ - event = asyncio.Event() - - def _inner_send_data_frame(stream_id: int, data, event: asyncio.Event): - self._write_data_queue.put_nowait((stream_id, data, event)) - - self._loop.call_soon_threadsafe(_inner_send_data_frame, stream_id, data, event) - - return event - - async def _send_data_loop(self) -> None: - """ - Asynchronous blocking to get data from write_data_queue and try to send it, - this method implements the flow control mechanism - """ - while True: - stream_id, data, event = await self._write_data_queue.get() - - # If this stream got reset, just drop the data on the floor. - if stream_id in self._reset_streams: - event.set() - continue - - if data is END_DATA_SENTINEL: - self.conn.end_stream(stream_id) - self.transport.write(self.conn.data_to_send()) - event.set() - continue - - # We need to send data, but not to exceed the flow control window. - window_size = self.conn.local_flow_control_window(stream_id) - chunk_size = min(window_size, len(data)) - data_to_send = data[:chunk_size] - data_to_buffer = data[chunk_size:] - - if data_to_send: - # Send the data frame - max_size = self.conn.max_outbound_frame_size - chunks = ( - data_to_send[x : x + max_size] - for x in range(0, len(data_to_send), max_size) - ) - for chunk in chunks: - self.conn.send_data(stream_id, chunk) - self.transport.write(self.conn.data_to_send()) - - if data_to_buffer: - # We still have data to send, but it's blocked by traffic control, - # so we need to wait for the traffic window to open again. - self._flow_controlled_data[stream_id] = ( - stream_id, - data_to_buffer, - event, - ) - else: - # We sent everything. 
- event.set() - - def data_received(self, data: bytes) -> None: - """ - Process inbound data. - """ - events = self.conn.receive_data(data) - for event in events: - self._process_event(event) - outbound_data = self.conn.data_to_send() - if outbound_data: - self.transport.write(outbound_data) - - def _process_event(self, event: h2.events.Event) -> Optional[bool]: - """ - Process an event. - """ - if isinstance(event, (RemoteSettingsChanged, PingReceived)): - # Events that are handled automatically by the H2 library. - # 1. RemoteSettingsChanged: h2 automatically acknowledges settings changes - # 2. PingReceived: A ping acknowledgment with the same opaque data is automatically emitted after receiving a ping. - pass - elif isinstance(event, WindowUpdated): - self.window_updated(event) - elif isinstance(event, StreamReset): - self.reset_stream(event) - else: - # A False here means that the current event is not handled and needs to be handled by the subclass. - return False - - def window_updated(self, event: WindowUpdated) -> None: - """ - The flow control window got opened. - - """ - if event.stream_id: - # This is specific to a single stream. - if event.stream_id in self._flow_controlled_data: - self._write_data_queue.put_nowait( - self._flow_controlled_data.pop(event.stream_id) - ) - else: - # This event is specific to the connection. - # Free up all the streams. - for data in self._flow_controlled_data.values(): - self._write_data_queue.put_nowait(data) - - self._flow_controlled_data = {} - - def reset_stream(self, event: StreamReset) -> None: - """ - The remote peer reset the stream. - """ - if event.stream_id in self._flow_controlled_data: - del self._flow_controlled_data - - self._reset_streams.add(event.stream_id) - - -class HTTP2ClientProtocol(HTTP2Protocol): - """ - An HTTP/2 client protocol. - """ - - def __init__(self): - h2_config = H2Configuration(client_side=True, header_encoding="utf-8") - super().__init__(h2_config) - - def register_stream(self, stream_id, stream): - self.streams[stream_id] = stream - - def _process_event(self, event): - if super()._process_event(event) is False: - if isinstance(event, ResponseReceived): - self.receive_headers(event) - elif isinstance(event, DataReceived): - self.receive_data(event) - elif isinstance(event, TrailersReceived): - self.receive_trailers(event) - elif isinstance(event, StreamEnded): - self.stream_ended(event) - - def receive_headers(self, event: ResponseReceived): - """ - The response headers have been received. - """ - self.streams[event.stream_id].receive_headers(event.headers) - - def receive_data(self, event: DataReceived): - """ - Data has been received. - """ - self.streams[event.stream_id].receive_data(event.data) - # Acknowledge the data, so the remote peer can send more. - self.conn.acknowledge_received_data( - event.flow_controlled_length, event.stream_id - ) - - def receive_trailers(self, event): - """ - Trailers have been received. - """ - self.streams[event.stream_id].receive_trailers(event.headers) - - def stream_ended(self, event): - """ - The stream has ended. - """ - self.streams[event.stream_id].receive_complete() - # Clean up the stream. 
- del self.streams[event.stream_id] - - def reset_stream(self, event: StreamReset) -> None: - super().reset_stream(event) - # TODO Pass the exception to the corresponding stream object - - -class HTTP2ServerProtocol(HTTP2Protocol): - - def __init__(self): - h2_config = H2Configuration(client_side=False, header_encoding="utf-8") - super().__init__(h2_config) - - def _process_event(self, event: h2.events.Event): - if super()._process_event(event) is False: - if isinstance(event, RequestReceived): - self.receive_headers(event) - elif isinstance(event, DataReceived): - self.receive_data(event) - elif isinstance(event, StreamEnded): - self.stream_ended(event) - - def receive_headers(self, event: RequestReceived): - """ - The request headers have been received. - """ - from dubbo.remoting.aio.aio_stream import AioServerStream - - s = AioServerStream(event.stream_id, self._loop, self) - self.streams[event.stream_id] = s - s.receive_headers(event.headers) - - def receive_data(self, event: DataReceived): - """ - Data has been received. - """ - self.streams[event.stream_id].receive_data(event.data) - - def stream_ended(self, event: StreamEnded): - """ - The stream has ended. - """ - self.streams[event.stream_id].receive_complete() From cd3c39e1ad4fb15c1265dd4ec018ad00fd3b6256 Mon Sep 17 00:00:00 2001 From: zaki Date: Mon, 8 Jul 2024 22:48:07 +0800 Subject: [PATCH 27/32] feat: Complete the client's call link --- dubbo/_dubbo.py | 3 +- dubbo/callable.py | 59 ++++++ dubbo/callable/rpc_callable.py | 79 ------- dubbo/callable/rpc_callable_factory.py | 37 ---- dubbo/client/client.py | 89 ++++---- dubbo/common/__init__.py | 15 -- dubbo/common/constants/__init__.py | 15 -- dubbo/compressor/compressor.py | 22 +- dubbo/config/logger_config.py | 8 +- dubbo/config/reference_config.py | 14 +- dubbo/{callable => constants}/__init__.py | 0 .../constants/common_constants.py | 13 +- .../constants/logger_constants.py | 0 .../{common => }/constants/type_constants.py | 0 dubbo/extension/__init__.py | 3 +- dubbo/extension/registry.py | 15 +- dubbo/logger/logger.py | 4 +- dubbo/logger/logger_factory.py | 8 +- dubbo/logger/logging/logger.py | 2 +- dubbo/logger/logging/logger_adapter.py | 8 +- dubbo/loop/__init__.py | 58 ------ dubbo/loop/loop_manger.py | 111 ---------- dubbo/{common => }/node.py | 2 +- dubbo/protocol/invocation.py | 69 ++++-- dubbo/protocol/invoker.py | 2 +- dubbo/protocol/protocol.py | 2 +- dubbo/protocol/result.py | 33 ++- dubbo/protocol/triple/stream.py | 119 ----------- dubbo/protocol/triple/tri_client.py | 196 ++++++++++++++++++ dubbo/protocol/triple/tri_codec.py | 196 ++++++++++++++++++ dubbo/protocol/triple/tri_decoder.py | 152 -------------- dubbo/protocol/triple/tri_invoker.py | 116 ++++++++++- .../{triple_protocol.py => tri_listener.py} | 19 +- dubbo/protocol/triple/tri_protocol.py | 58 ++++++ dubbo/protocol/triple/tri_rpc_status.py | 57 +++++ dubbo/remoting/aio/aio_transporter.py | 140 +++++++++++-- dubbo/remoting/aio/constants.py | 18 -- dubbo/remoting/aio/h2_frame.py | 11 +- dubbo/remoting/aio/h2_protocol.py | 45 +++- dubbo/remoting/aio/h2_stream.py | 119 ++++++++--- dubbo/remoting/aio/h2_stream_handler.py | 44 ++-- dubbo/remoting/aio/loop.py | 150 ++++++++++++++ dubbo/remoting/transporter.py | 58 +++++- dubbo/serialization.py | 4 +- dubbo/{common => }/url.py | 38 +++- tests/common/tets_url.py | 16 +- tests/logger/test_logger_factory.py | 4 +- tests/logger/test_logging_logger.py | 2 +- 48 files changed, 1399 insertions(+), 834 deletions(-) create mode 100644 dubbo/callable.py delete mode 
100644 dubbo/callable/rpc_callable.py delete mode 100644 dubbo/callable/rpc_callable_factory.py delete mode 100644 dubbo/common/__init__.py delete mode 100644 dubbo/common/constants/__init__.py rename dubbo/{callable => constants}/__init__.py (100%) rename dubbo/{common => }/constants/common_constants.py (78%) rename dubbo/{common => }/constants/logger_constants.py (100%) rename dubbo/{common => }/constants/type_constants.py (100%) delete mode 100644 dubbo/loop/__init__.py delete mode 100644 dubbo/loop/loop_manger.py rename dubbo/{common => }/node.py (97%) delete mode 100644 dubbo/protocol/triple/stream.py create mode 100644 dubbo/protocol/triple/tri_client.py create mode 100644 dubbo/protocol/triple/tri_codec.py delete mode 100644 dubbo/protocol/triple/tri_decoder.py rename dubbo/protocol/triple/{triple_protocol.py => tri_listener.py} (68%) create mode 100644 dubbo/protocol/triple/tri_protocol.py create mode 100644 dubbo/protocol/triple/tri_rpc_status.py delete mode 100644 dubbo/remoting/aio/constants.py create mode 100644 dubbo/remoting/aio/loop.py rename dubbo/{common => }/url.py (92%) diff --git a/dubbo/_dubbo.py b/dubbo/_dubbo.py index fece509..05a096f 100644 --- a/dubbo/_dubbo.py +++ b/dubbo/_dubbo.py @@ -16,7 +16,8 @@ import threading from typing import Dict, List -from dubbo.config import ApplicationConfig, ConsumerConfig, LoggerConfig, ProtocolConfig +from dubbo.config import (ApplicationConfig, ConsumerConfig, LoggerConfig, + ProtocolConfig) from dubbo.logger.logger_factory import loggerFactory logger = loggerFactory.get_logger(__name__) diff --git a/dubbo/callable.py b/dubbo/callable.py new file mode 100644 index 0000000..749dddb --- /dev/null +++ b/dubbo/callable.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +from dubbo.constants import common_constants +from dubbo.protocol.invocation import RpcInvocation +from dubbo.protocol.invoker import Invoker +from dubbo.url import URL + + +class RpcCallable: + + def __init__(self, invoker: Invoker, url: URL): + self._invoker = invoker + self._url = url + self._service_name = self._url.path or "" + self._method_name = self._url.get_parameter(common_constants.METHOD_KEY) or "" + self._call_type = self._url.get_parameter(common_constants.CALL_KEY) + self._request_serializer = ( + self._url.get_attribute(common_constants.SERIALIZATION) or None + ) + self._response_serializer = ( + self._url.get_attribute(common_constants.DESERIALIZATION) or None + ) + + def _do_call(self, argument: Any) -> Any: + """ + Real call method. + """ + # Create a new RpcInvocation object. 
+ invocation = RpcInvocation( + self._service_name, + self._method_name, + argument, + attributes={ + common_constants.CALL_KEY: self._call_type, + common_constants.SERIALIZATION: self._request_serializer, + common_constants.DESERIALIZATION: self._response_serializer, + }, + ) + # Do invoke. + result = self._invoker.invoke(invocation) + return result.get_value() + + def __call__(self, argument: Any) -> Any: + return self._do_call(argument) diff --git a/dubbo/callable/rpc_callable.py b/dubbo/callable/rpc_callable.py deleted file mode 100644 index 9171e1f..0000000 --- a/dubbo/callable/rpc_callable.py +++ /dev/null @@ -1,79 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -from typing import Any - -from dubbo.common.constants import common_constants -from dubbo.common.url import URL -from dubbo.protocol.invocation import RpcInvocation -from dubbo.protocol.invoker import Invoker - - -class RpcCallable: - - def __init__(self, invoker: Invoker, url: URL): - self._invoker = invoker - self._url = url - self._service_name = self._url.path or "" - method_url = self._url.get_attribute(common_constants.METHOD_KEY) - self._method_name = method_url.get_parameter(common_constants.METHOD_KEY) or "" - self._call_type = method_url.get_parameter(common_constants.TYPE_CALL) - self._req_serializer = ( - method_url.get_attribute(common_constants.SERIALIZATION) or None - ) - self._res_serializer = ( - method_url.get_attribute(common_constants.SERIALIZATION) or None - ) - - async def _do_call(self, argument: Any): - """ - Real call method. - """ - if ( - self._call_type == common_constants.CALL_CLIENT_STREAM - and not inspect.isgeneratorfunction(argument) - ): - raise ValueError( - "Invalid argument: The provided argument must be a generator function " - ) - elif ( - self._call_type == common_constants.CALL_UNARY - and inspect.isgeneratorfunction(argument) - ): - raise ValueError( - "Invalid argument: The provided argument must be a normal function" - ) - - # Create a new RpcInvocation object. - invocation = RpcInvocation( - self._service_name, - self._method_name, - argument, - self._req_serializer, - self._res_serializer, - ) - # Do invoke. - result = self._invoker.invoke(invocation) - return result - - async def __call__(self, argument: Any): - return await self._do_call(argument) - - -class AsyncRpcCallable: - - async def __call__(self, *args, **kwargs): - pass diff --git a/dubbo/callable/rpc_callable_factory.py b/dubbo/callable/rpc_callable_factory.py deleted file mode 100644 index 55edbba..0000000 --- a/dubbo/callable/rpc_callable_factory.py +++ /dev/null @@ -1,37 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dubbo.callable.rpc_callable import RpcCallable -from dubbo.common.url import URL -from dubbo.protocol.invoker import Invoker - - -class RpcCallableFactory: - - def get_proxy(self, url: URL, invoker: Invoker) -> RpcCallable: - """ - Get the callable object. - Args: - url (URL): The URL. - invoker (Invoker): The invoker object. - """ - raise NotImplementedError("get_proxy() is not implemented") - - -class DefaultRpcCallableFactory(RpcCallableFactory): - - def get_proxy(self, url: URL, invoker: Invoker) -> RpcCallable: - pass diff --git a/dubbo/client/client.py b/dubbo/client/client.py index f929029..ecefa8d 100644 --- a/dubbo/client/client.py +++ b/dubbo/client/client.py @@ -13,17 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Optional -from dubbo.callable.rpc_callable import AsyncRpcCallable, RpcCallable -from dubbo.callable.rpc_callable_factory import DefaultRpcCallableFactory -from dubbo.common.constants import common_constants -from dubbo.common.constants.type_constants import ( - DeserializingFunction, - SerializingFunction, -) -from dubbo.common.url import URL +from dubbo.callable import RpcCallable from dubbo.config import ConsumerConfig, ReferenceConfig +from dubbo.constants import common_constants +from dubbo.constants.type_constants import (DeserializingFunction, + SerializingFunction) from dubbo.logger.logger_factory import loggerFactory logger = loggerFactory.get_logger(__name__) @@ -31,9 +27,6 @@ class Client: - _consumer: ConsumerConfig - _reference: ReferenceConfig - __slots__ = ["_consumer", "_reference"] def __init__( @@ -45,66 +38,66 @@ def __init__( def unary( self, method_name: str, - req_serializer: Optional[SerializingFunction] = None, - resp_deserializer: Optional[DeserializingFunction] = None, - ) -> Union[RpcCallable, AsyncRpcCallable]: + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ) -> RpcCallable: return self._callable( - common_constants.CALL_UNARY, method_name, req_serializer, resp_deserializer + common_constants.CALL_UNARY, method_name, request_serializer, response_deserializer ) def client_stream( self, method_name: str, - req_serializer: Optional[SerializingFunction] = None, - resp_deserializer: Optional[DeserializingFunction] = None, - ) -> Union[RpcCallable, AsyncRpcCallable]: + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ) -> RpcCallable: return self._callable( common_constants.CALL_CLIENT_STREAM, method_name, - req_serializer, - resp_deserializer, + request_serializer, + response_deserializer, ) def server_stream( self, 
method_name: str, - req_serializer: Optional[SerializingFunction] = None, - resp_deserializer: Optional[DeserializingFunction] = None, - ) -> Union[RpcCallable, AsyncRpcCallable]: + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ) -> RpcCallable: return self._callable( common_constants.CALL_SERVER_STREAM, method_name, - req_serializer, - resp_deserializer, + request_serializer, + response_deserializer, ) def bidi_stream( self, method_name: str, - req_serializer: Optional[SerializingFunction] = None, - resp_deserializer: Optional[DeserializingFunction] = None, - ) -> Union[RpcCallable, AsyncRpcCallable]: + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ) -> RpcCallable: return self._callable( common_constants.CALL_BIDI_STREAM, method_name, - req_serializer, - resp_deserializer, + request_serializer, + response_deserializer, ) def _callable( self, call_type: str, method_name: str, - req_serializer: Optional[SerializingFunction] = None, - resp_deserializer: Optional[DeserializingFunction] = None, - ) -> Union[RpcCallable, AsyncRpcCallable]: + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ) -> RpcCallable: """ Generate a callable for the given method Args: call_type: call type method_name: method name - req_serializer: request serializer, args: Any, return: bytes - resp_deserializer: response deserializer, args: bytes, return: Any + request_serializer: request serializer, args: Any, return: bytes + response_deserializer: response deserializer, args: bytes, return: Any Returns: RpcCallable: The callable object """ @@ -112,22 +105,12 @@ def _callable( invoker = self._reference.get_invoker() url = invoker.get_url() - method_url = URL( - method_name, - common_constants.LOCALHOST_KEY, - parameters={ - common_constants.METHOD_KEY: method_name, - common_constants.TYPE_CALL: call_type, - }, - ) - # add attributes - method_url.add_attribute(common_constants.SERIALIZATION, req_serializer) - method_url.add_attribute(common_constants.DESERIALIZATION, resp_deserializer) - - # put the method url into the invoker url - url.add_attribute(method_name, method_url) + # clone url + url = url.clone() + url.add_parameter(common_constants.METHOD_KEY, method_name) + url.add_parameter(common_constants.CALL_KEY, call_type) + url.add_attribute(common_constants.SERIALIZATION, request_serializer) + url.add_attribute(common_constants.DESERIALIZATION, response_deserializer) # create callable - rpc_callable = DefaultRpcCallableFactory().get_proxy(invoker, url) - - return rpc_callable + return RpcCallable(invoker, url) diff --git a/dubbo/common/__init__.py b/dubbo/common/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/dubbo/common/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dubbo/common/constants/__init__.py b/dubbo/common/constants/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/dubbo/common/constants/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dubbo/compressor/compressor.py b/dubbo/compressor/compressor.py index 2edbc85..602a35b 100644 --- a/dubbo/compressor/compressor.py +++ b/dubbo/compressor/compressor.py @@ -15,7 +15,27 @@ # limitations under the License. +class Compressor: + + def compress(self, data: bytes) -> bytes: + """ + Compress the data + Args: + data (bytes): Data to compress + Returns: + bytes: Compressed data + """ + raise NotImplementedError("compress() is not implemented.") + + class DeCompressor: def decompress(self, data: bytes) -> bytes: - pass + """ + Decompress the data + Args: + data (bytes): Data to decompress + Returns: + bytes: Decompressed data + """ + raise NotImplementedError("decompress() is not implemented.") diff --git a/dubbo/config/logger_config.py b/dubbo/config/logger_config.py index d91d5ba..dfdf8ab 100644 --- a/dubbo/config/logger_config.py +++ b/dubbo/config/logger_config.py @@ -16,12 +16,12 @@ from dataclasses import dataclass from typing import Dict, Optional -from dubbo.common.constants import logger_constants as logger_constants -from dubbo.common.constants.logger_constants import FileRotateType, Level -from dubbo.common.url import URL +from dubbo.constants import logger_constants as logger_constants +from dubbo.constants.logger_constants import FileRotateType, Level from dubbo.extension import extensionLoader from dubbo.logger import LoggerAdapter from dubbo.logger.logger_factory import loggerFactory +from dubbo.url import URL @dataclass @@ -123,7 +123,7 @@ def get_url(self) -> URL: **self._file_config.dict(), } - return URL(protocol=self._driver, host=self._level.value, parameters=parameters) + return URL(scheme=self._driver, host=self._level.value, parameters=parameters) def init(self): # get logger_adapter and initialize loggerFactory diff --git a/dubbo/config/reference_config.py b/dubbo/config/reference_config.py index fd30d8a..3015f50 100644 --- a/dubbo/config/reference_config.py +++ b/dubbo/config/reference_config.py @@ -16,12 +16,11 @@ import threading from typing import List, Optional -from dubbo.callable.rpc_callable_factory import RpcCallableFactory -from dubbo.common.url import 
URL from dubbo.config.method_config import MethodConfig from dubbo.extension import extensionLoader from dubbo.protocol.invoker import Invoker from dubbo.protocol.protocol import Protocol +from dubbo.url import URL class ReferenceConfig: @@ -37,12 +36,10 @@ class ReferenceConfig: _destroyed: bool _protocol_ins: Optional[Protocol] _invoker: Optional[Invoker] - _proxy_factory: Optional[RpcCallableFactory] def __init__( self, interface_name: str, - check: bool, url: str, protocol: str, methods: Optional[List[MethodConfig]] = None, @@ -55,6 +52,8 @@ def __init__( self._protocol = protocol self._methods = methods or [] + self._invoker = None + def get_invoker(self): if not self._invoker: self._do_init() @@ -66,9 +65,12 @@ def _do_init(self): return clazz = extensionLoader.get_extension(Protocol, self._protocol) - self._protocol_ins = clazz() + # TODO set real URL + self._protocol_ins = clazz(URL.value_of(self._url)) self._create_invoker() self._initialized = True def _create_invoker(self): - self._invoker = self._protocol_ins.refer(URL.value_of(self._url)) + url = URL.value_of(self._url) + url.path = self._interface_name + self._invoker = self._protocol_ins.refer(url) diff --git a/dubbo/callable/__init__.py b/dubbo/constants/__init__.py similarity index 100% rename from dubbo/callable/__init__.py rename to dubbo/constants/__init__.py diff --git a/dubbo/common/constants/common_constants.py b/dubbo/constants/common_constants.py similarity index 78% rename from dubbo/common/constants/common_constants.py rename to dubbo/constants/common_constants.py index c985045..ebf4a96 100644 --- a/dubbo/common/constants/common_constants.py +++ b/dubbo/constants/common_constants.py @@ -20,7 +20,7 @@ LOCALHOST_KEY = "localhost" LOCALHOST_VALUE = "127.0.0.1" -TYPE_CALL = "call" +CALL_KEY = "call" CALL_UNARY = "unary" CALL_CLIENT_STREAM = "client-stream" CALL_SERVER_STREAM = "server-stream" @@ -28,10 +28,19 @@ SERIALIZATION = "serialization" DESERIALIZATION = "deserialization" +COMPRESSOR_KEY = "compressor" +DECOMPRESSOR_KEY = "decompressor" SERVER_KEY = "server" METHOD_KEY = "method" - TRUE_VALUE = "true" FALSE_VALUE = "false" + + +# Constants about the transporter. +TRANSPORTER_KEY = "transporter" +TRANSPORTER_SIDE_KEY = "transporter-side" +TRANSPORTER_SIDE_SERVER = "server" +TRANSPORTER_SIDE_CLIENT = "client" +TRANSPORTER_ON_CONN_CLOSE_KEY = "on-conn-close" diff --git a/dubbo/common/constants/logger_constants.py b/dubbo/constants/logger_constants.py similarity index 100% rename from dubbo/common/constants/logger_constants.py rename to dubbo/constants/logger_constants.py diff --git a/dubbo/common/constants/type_constants.py b/dubbo/constants/type_constants.py similarity index 100% rename from dubbo/common/constants/type_constants.py rename to dubbo/constants/type_constants.py diff --git a/dubbo/extension/__init__.py b/dubbo/extension/__init__.py index 0da2118..8744a34 100644 --- a/dubbo/extension/__init__.py +++ b/dubbo/extension/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dubbo.extension.extension_loader import ExtensionLoader as _ExtensionLoader +from dubbo.extension.extension_loader import \ + ExtensionLoader as _ExtensionLoader extensionLoader = _ExtensionLoader() diff --git a/dubbo/extension/registry.py b/dubbo/extension/registry.py index c0d0b12..71904b7 100644 --- a/dubbo/extension/registry.py +++ b/dubbo/extension/registry.py @@ -16,9 +16,11 @@ import inspect import sys from dataclasses import dataclass -from typing import Any, Protocol +from typing import Any from dubbo.logger import LoggerAdapter +from dubbo.protocol.protocol import Protocol +from dubbo.remoting.transporter import Transporter @dataclass @@ -38,10 +40,19 @@ class ExtendedRegistry: protocolRegistry = ExtendedRegistry( interface=Protocol, impls={ - "tri": "dubbo.protocol.triple.triple_protocol.TripleProtocol", + "tri": "dubbo.protocol.triple.tri_protocol.TripleProtocol", }, ) +"""Transporter registry.""" +transporterRegistry = ExtendedRegistry( + interface=Transporter, + impls={ + "aio": "dubbo.remoting.aio.aio_transporter.AioTransporter", + }, +) + + """LoggerAdapter registry.""" loggerAdapterRegistry = ExtendedRegistry( interface=LoggerAdapter, diff --git a/dubbo/logger/logger.py b/dubbo/logger/logger.py index 11f3595..00607a8 100644 --- a/dubbo/logger/logger.py +++ b/dubbo/logger/logger.py @@ -15,8 +15,8 @@ # limitations under the License. from typing import Any -from dubbo.common.constants.logger_constants import Level -from dubbo.common.url import URL +from dubbo.constants.logger_constants import Level +from dubbo.url import URL class Logger: diff --git a/dubbo/logger/logger_factory.py b/dubbo/logger/logger_factory.py index 83024d4..59a291b 100644 --- a/dubbo/logger/logger_factory.py +++ b/dubbo/logger/logger_factory.py @@ -16,15 +16,15 @@ import threading from typing import Dict -from dubbo.common.constants import logger_constants as logger_constants -from dubbo.common.constants.logger_constants import Level -from dubbo.common.url import URL +from dubbo.constants import logger_constants as logger_constants +from dubbo.constants.logger_constants import Level from dubbo.logger.logger import Logger, LoggerAdapter from dubbo.logger.logging.logger_adapter import LoggingLoggerAdapter +from dubbo.url import URL # Default logger config with default values. _default_config = URL( - protocol=logger_constants.DEFAULT_DRIVER_VALUE, + scheme=logger_constants.DEFAULT_DRIVER_VALUE, host=logger_constants.DEFAULT_LEVEL_VALUE.value, parameters={ logger_constants.DRIVER_KEY: logger_constants.DEFAULT_DRIVER_VALUE, diff --git a/dubbo/logger/logging/logger.py b/dubbo/logger/logging/logger.py index 0a3887a..8fcb929 100644 --- a/dubbo/logger/logging/logger.py +++ b/dubbo/logger/logging/logger.py @@ -17,7 +17,7 @@ import logging from typing import Dict -from dubbo.common.constants.logger_constants import Level +from dubbo.constants.logger_constants import Level from dubbo.logger import Logger # The mapping from the logging level to the logging level. 
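For readers skimming the registry change above: an ExtendedRegistry maps an interface (Protocol, Transporter, LoggerAdapter) to implementation classes addressed by dotted-path strings such as "dubbo.protocol.triple.tri_protocol.TripleProtocol". Below is a minimal, hypothetical sketch of how such a string can be resolved to a class; the project's actual ExtensionLoader may add caching, validation, and richer error handling.

import importlib

def _resolve_impl(dotted_path: str) -> type:
    """Resolve 'package.module.ClassName' to the class object (illustrative only)."""
    module_name, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# With a mapping like the one above, a loader can resolve and instantiate an
# implementation, which is how extensionLoader.get_extension(Protocol, "tri")
# is used elsewhere in this patch (reference_config.py, tri_protocol.py).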
diff --git a/dubbo/logger/logging/logger_adapter.py b/dubbo/logger/logging/logger_adapter.py index e0ce6eb..c8a20ca 100644 --- a/dubbo/logger/logging/logger_adapter.py +++ b/dubbo/logger/logging/logger_adapter.py @@ -20,13 +20,13 @@ from functools import cache from logging import handlers -from dubbo.common.constants import common_constants -from dubbo.common.constants import logger_constants as logger_constants -from dubbo.common.constants.logger_constants import FileRotateType, Level -from dubbo.common.url import URL +from dubbo.constants import common_constants +from dubbo.constants import logger_constants as logger_constants +from dubbo.constants.logger_constants import FileRotateType, Level from dubbo.logger import Logger, LoggerAdapter from dubbo.logger.logging import formatter from dubbo.logger.logging.logger import LoggingLogger +from dubbo.url import URL """This module provides the logging logger implementation. -> logging module""" diff --git a/dubbo/loop/__init__.py b/dubbo/loop/__init__.py deleted file mode 100644 index a7ebe86..0000000 --- a/dubbo/loop/__init__.py +++ /dev/null @@ -1,58 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dubbo.loop.loop_manger import LoopManager as _LoopManager - - -def _try_use_uvloop() -> None: - """ - Use uvloop instead of the default asyncio loop. - """ - import asyncio - import os - - from dubbo.logger.logger_factory import loggerFactory - - logger = loggerFactory.get_logger("try_use_uvloop") - - # Check if the operating system. - if os.name == "nt": - # Windows is not supported. - logger.warning( - "Unable to use uvloop, because it is not supported on your operating system." - ) - return - - # Try import uvloop. - try: - import uvloop - except ImportError: - # uvloop is not available. - logger.warning( - "Unable to use uvloop, because it is not installed. " - "You can install it by running `pip install uvloop`." - ) - return - - # Use uvloop instead of the default asyncio loop. - if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - - -# Call the function to try to use uvloop. -_try_use_uvloop() - -loopManager = _LoopManager() diff --git a/dubbo/loop/loop_manger.py b/dubbo/loop/loop_manger.py deleted file mode 100644 index 825f2c7..0000000 --- a/dubbo/loop/loop_manger.py +++ /dev/null @@ -1,111 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import threading -from typing import Optional - -from dubbo.logger.logger_factory import loggerFactory - -logger = loggerFactory.get_logger(__name__) - - -def start_loop(loop): - """ - Start the loop. - Args: - loop: The loop to start. - """ - asyncio.set_event_loop(loop) - loop.run_forever() - - -class LoopManager: - """ - Loop manager. - It used to manage the global event loop and therefore designed as a singleton pattern. - Attributes: - _instance: The instance of the loop manager. - _ins_lock: The lock to protect the instance. - _client_initialized: Whether the client is initialized. - _client_destroyed: Whether the client is destroyed. - _client_loop_info: The client info. (thread, loop) - _cli_lock: The lock to protect the client info. - """ - - _instance = None - _ins_lock = threading.Lock() - - # About client - _client_initialized = False - _client_destroyed = False - _client_loop_info = None - _cli_lock = threading.Lock() - - def __new__(cls): - if cls._instance is None: - with cls._ins_lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def _init_client_loop(self): - """ - Initialize the client loop. - return: The client info. (thread, loop) - """ - new_loop = asyncio.new_event_loop() - # Start the loop in a new thread - thread = threading.Thread( - target=start_loop, args=(new_loop,), name="dubbo-client-loop", daemon=True - ) - thread.start() - self._client_loop_info = (thread, new_loop) - self._client_initialized = True - logger.info("The client loop is initialized.") - return self._client_loop_info - - def get_client_loop(self) -> Optional[asyncio.AbstractEventLoop]: - """ - Get the client loop. Lazy initialization. - return: If the client is destroyed, return None. Otherwise, return the client loop. - """ - if self._client_destroyed: - logger.error("The client is destroyed.") - return None - - if not self._client_initialized: - with self._cli_lock: - if not self._client_initialized: - self._init_client_loop() - return self._client_loop_info[1] - - def destroy_client_loop(self) -> None: - """ - Destroy the client. This method can only be called once. - """ - if self._client_destroyed: - logger.info("The client is already destroyed.") - return - - with self._cli_lock: - if not self._client_destroyed: - client_loop_info = self._client_loop_info - # Stop the loop - client_loop_info[1].stop() - # Wait for the loop to stop - client_loop_info[0].join() - self._client_destroyed = True - logger.info("The client is destroyed.") diff --git a/dubbo/common/node.py b/dubbo/node.py similarity index 97% rename from dubbo/common/node.py rename to dubbo/node.py index 71d64df..f63e12b 100644 --- a/dubbo/common/node.py +++ b/dubbo/node.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dubbo.common.url import URL +from dubbo.url import URL class Node: diff --git a/dubbo/protocol/invocation.py b/dubbo/protocol/invocation.py index 4e4a7f6..59f3b03 100644 --- a/dubbo/protocol/invocation.py +++ b/dubbo/protocol/invocation.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Dict, Optional class Invocation: @@ -44,8 +44,8 @@ class RpcInvocation(Invocation): service_name (str): The name of the service. method_name (str): The name of the method. argument (Any): The method argument. - req_serializer (Any): The request serializer. - res_serializer (Any): The response serializer. + attachments (Optional[Dict[str, str]]): Passed to the remote server during RPC call + attributes (Optional[Dict[str, Any]]): Only used on the caller side, will not appear on the wire. """ def __init__( @@ -53,26 +53,63 @@ def __init__( service_name: str, method_name: str, argument: Any, - req_serializer=None, - res_serializer=None, + attachments: Optional[Dict[str, str]] = None, + attributes: Optional[Dict[str, Any]] = None, ): self._service_name = service_name self._method_name = method_name self._argument = argument - self._req_serializer = req_serializer - self._res_serializer = res_serializer + self._attachments = attachments or {} + self._attributes = attributes or {} - def get_service_name(self): + def add_attachment(self, key: str, value: str) -> None: + """ + Add an attachment to the invocation. + Args: + key (str): The key of the attachment. + value (str): The value of the attachment. + """ + self._attachments[key] = value + + def get_attachment(self, key: str) -> Optional[str]: + """ + Get the attachment of the invocation. + Args: + key (str): The key of the attachment. + Returns: + The value of the attachment. If the attachment does not exist, return None. + """ + return self._attachments.get(key, None) + + def add_attribute(self, key: str, value: Any) -> None: + """ + Add an attribute to the invocation. + Args: + key (str): The key of the attribute. + value (Any): The value of the attribute. + """ + self._attributes[key] = value + + def get_attribute(self, key: str) -> Optional[Any]: + """ + Get the attribute of the invocation. + Args: + key (str): The key of the attribute. + Returns: + The value of the attribute. If the attribute does not exist, return None. + """ + return self._attributes.get(key, None) + + def get_service_name(self) -> str: + """ + Get the service name. + Returns: + The service name. + """ return self._service_name - def get_method_name(self): + def get_method_name(self) -> str: return self._method_name - def get_argument(self): + def get_argument(self) -> Any: return self._argument - - def get_req_serializer(self): - return self._req_serializer - - def get_res_serializer(self): - return self._res_serializer diff --git a/dubbo/protocol/invoker.py b/dubbo/protocol/invoker.py index 8d5b64d..763372f 100644 --- a/dubbo/protocol/invoker.py +++ b/dubbo/protocol/invoker.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
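To make the new attachments/attributes API of RpcInvocation above concrete, here is a hedged usage sketch for a unary call. The attribute keys mirror the ones TriInvoker reads later in this patch (CALL_KEY, SERIALIZATION, DESERIALIZATION); the service name, method name, and argument bytes are invented for illustration.

from dubbo.constants import common_constants
from dubbo.protocol.invocation import RpcInvocation

invocation = RpcInvocation(
    service_name="org.apache.dubbo.samples.Greeter",  # hypothetical service
    method_name="sayHello",                            # hypothetical method
    argument=b"\x08\x01",                               # pre-serialized request bytes
    attributes={
        common_constants.CALL_KEY: common_constants.CALL_UNARY,
        common_constants.SERIALIZATION: None,            # optionally a callable: obj -> bytes
        common_constants.DESERIALIZATION: None,          # optionally a callable: bytes -> obj
    },
)
assert invocation.get_attribute(common_constants.CALL_KEY) == "unary"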
-from dubbo.common.node import Node +from dubbo.node import Node from dubbo.protocol.invocation import Invocation from dubbo.protocol.result import Result diff --git a/dubbo/protocol/protocol.py b/dubbo/protocol/protocol.py index 5ae08a0..7de46f1 100644 --- a/dubbo/protocol/protocol.py +++ b/dubbo/protocol/protocol.py @@ -13,8 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dubbo.common.url import URL from dubbo.protocol.invoker import Invoker +from dubbo.url import URL class Protocol: diff --git a/dubbo/protocol/result.py b/dubbo/protocol/result.py index 06b54e1..53d0480 100644 --- a/dubbo/protocol/result.py +++ b/dubbo/protocol/result.py @@ -13,7 +13,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any class Result: - pass + """ + Result of a call + """ + + def set_value(self, value: Any) -> None: + """ + Set the value of the result + Args: + value: Value to set + """ + raise NotImplementedError("set_value() is not implemented.") + + def get_value(self) -> Any: + """ + Get the value of the result + """ + raise NotImplementedError("get_value() is not implemented.") + + def set_exception(self, exception: Exception) -> None: + """ + Set the exception on the result + Args: + exception: Exception to set + """ + raise NotImplementedError("set_exception() is not implemented.") + + def get_exception(self) -> Exception: + """ + Get the exception of the result + """ + raise NotImplementedError("get_exception() is not implemented.") diff --git a/dubbo/protocol/triple/stream.py b/dubbo/protocol/triple/stream.py deleted file mode 100644 index 65264c1..0000000 --- a/dubbo/protocol/triple/stream.py +++ /dev/null @@ -1,119 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Tuple - - -class Stream: - """ - Stream is a bi-directional channel that manipulates the data flow between peers. - Inbound data from remote peer is acquired by Stream.Listener. - Outbound data to remote peer is sent directly by Stream. - """ - - def __init__(self, stream_id: int): - self._stream_id = stream_id - - def send_headers(self, headers: List[Tuple[str, str]]) -> None: - """ - First call: head frame - Second call: trailer frame. - Args: - headers: The headers to send. - """ - raise NotImplementedError("send_headers() is not implemented") - - def send_data(self, data: bytes) -> None: - """ - Send the data frame - Args: - data: The data to send.
- """ - raise NotImplementedError("send_data() is not implemented") - - def send_end_stream(self) -> None: - """ - Send the end stream frame -> An empty data frame will be sent (end_stream=True) - """ - raise NotImplementedError("send_completed() is not implemented") - - class Listener: - """ - Listener is the interface that receives the data from the stream. - """ - - def on_headers(self, headers: List[Tuple[str, str]]) -> None: - """ - Called when the header frame is received - Args: - headers: The headers received. - """ - raise NotImplementedError("receive_headers() is not implemented") - - def on_data(self, data: bytes) -> None: - """ - Called when the data frame is received - Args: - data: The data received. - """ - raise NotImplementedError("receive_data() is not implemented") - - def on_complete(self) -> None: - """ - Complete the stream. - """ - raise NotImplementedError("complete() is not implemented") - - -class ClientStream(Stream): - """ - ClientStream is a Stream that is initiated by the client. - """ - - pass - - class Listener(Stream.Listener): - """ - Listener is the interface that receives the data from the stream. - """ - - def on_trailers(self, headers: List[Tuple[str, str]]) -> None: - """ - Called when the trailers frame is received - Args: - headers: The trailers received. - """ - raise NotImplementedError("receive_trailers() is not implemented") - - -class ServerStream(Stream): - """ - ServerStream is a Stream that is initiated by the server. - """ - - def send_trailers(self, trailers: List[Tuple[str, str]]) -> None: - """ - Send the trailers frame - Args: - trailers: The trailers to send. - """ - raise NotImplementedError("send_trailers() is not implemented") - - class Listener(Stream.Listener): - """ - Listener is the interface that receives the data from the stream. - """ - - pass diff --git a/dubbo/protocol/triple/tri_client.py b/dubbo/protocol/triple/tri_client.py new file mode 100644 index 0000000..5240f61 --- /dev/null +++ b/dubbo/protocol/triple/tri_client.py @@ -0,0 +1,196 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import queue +from typing import Any, List, Optional, Tuple + +from dubbo.compressor.compressor import Compressor, DeCompressor +from dubbo.constants import common_constants +from dubbo.constants.common_constants import CALL_CLIENT_STREAM, CALL_UNARY +from dubbo.constants.type_constants import (DeserializingFunction, + SerializingFunction) +from dubbo.extension import extensionLoader +from dubbo.protocol.result import Result +from dubbo.protocol.triple.tri_codec import TriDecoder, TriEncoder +from dubbo.remoting.aio.h2_stream import Stream +from dubbo.url import URL + + +class TriClientCall(Stream.Listener): + + def __init__( + self, + listener: "TriClientCall.Listener", + url: URL, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ): + self._stream: Optional[Stream] = None + self._listener = listener + + # Try to get the compressor and decompressor from the URL + self._compressor = self._decompressor = None + if compressor_str := url.get_parameter(common_constants.COMPRESSOR_KEY): + self._compressor = extensionLoader.get_extension(Compressor, compressor_str) + if decompressor_str := url.get_parameter(common_constants.DECOMPRESSOR_KEY): + self._decompressor = extensionLoader.get_extension( + DeCompressor, decompressor_str + ) + + self._compressed = self._compressor is not None + self._encoder = TriEncoder(self._compressor) + self._request_serializer = request_serializer + + class TriDecoderListener(TriDecoder.Listener): + + def __init__( + self, + _listener: "TriClientCall.Listener", + _response_deserializer: Optional[DeserializingFunction] = None, + ): + self._listener = _listener + self._response_deserializer = _response_deserializer + + def on_message(self, message: bytes): + if self._response_deserializer: + message = self._response_deserializer(message) + self._listener.on_message(message) + + def close(self): + self._listener.on_complete() + + self._response_deserializer = response_deserializer + self._decoder = TriDecoder( + TriDecoderListener(self._listener, self._response_deserializer), + self._decompressor, + ) + + self._header_received = False + self._headers = None + self._trailers = None + + def bind_stream(self, stream: Stream) -> None: + """ + Bind stream + """ + self._stream = stream + + def send_headers(self, headers: List[Tuple[str, str]], last: bool = False) -> None: + """ + Send headers + Args: + headers (List[Tuple[str, str]]): Headers + last (bool): Last frame or not + """ + self._stream.send_headers(headers, end_stream=last) + + def send_message(self, message: Any, last: bool = False) -> None: + """ + Send a message + Args: + message (Any): Message to send + last (bool): Last frame or not + """ + if self._request_serializer: + data = self._request_serializer(message) + elif isinstance(message, bytes): + data = message + else: + raise TypeError("Message must be bytes or serialized by req_serializer") + + # Encode data + frame_payload = self._encoder.encode(data, self._compressed) + # Send data frame + self._stream.send_data(frame_payload, end_stream=last) + + def on_headers(self, headers: List[Tuple[str, str]]) -> None: + if not self._header_received: + self._headers = headers + self._header_received = True + else: + # receive trailers + self._trailers = headers + + def on_data(self, data: bytes) -> None: + self._decoder.decode(data) + + def on_complete(self) -> None: + self._decoder.close() + + def on_reset(self, err_code: int) -> None: + # TODO: handle reset + pass + + class Listener: + + 
def on_message(self, message: Any) -> None: + """ + Callback when message is received + """ + raise NotImplementedError("on_message() is not implemented") + + def on_complete(self) -> None: + """ + Callback when the stream is complete + """ + raise NotImplementedError("on_complete() is not implemented") + + +class TriResult(Result): + """ + Triple result + """ + + END_SIGNAL = object() + + def __init__(self, call_type: str): + self._call_type = call_type + self._value_queue = queue.Queue() + self._exception = None + + def set_value(self, value: Any) -> None: + self._value_queue.put(value) + if self._call_type in [CALL_UNARY, CALL_CLIENT_STREAM]: + # Notify the caller that the value is ready + self._value_queue.put(self.END_SIGNAL) + + def get_value(self) -> Any: + if self._call_type in [CALL_UNARY, CALL_CLIENT_STREAM]: + return self._get_single_value() + else: + return self._iterating_values() + + def _get_single_value(self) -> Any: + value = self._value_queue.get() + if value is self.END_SIGNAL: + return None + return value + + def _iterating_values(self) -> Any: + while True: + # block until the value is ready + value = self._value_queue.get() + if value is self.END_SIGNAL: + # break the loop when the value is end signal + break + yield value + + def set_exception(self, exception: Exception) -> None: + # close the value queue + self._value_queue.put(None) + self._exception = exception + + def get_exception(self) -> Exception: + return self._exception diff --git a/dubbo/protocol/triple/tri_codec.py b/dubbo/protocol/triple/tri_codec.py new file mode 100644 index 0000000..b0711a7 --- /dev/null +++ b/dubbo/protocol/triple/tri_codec.py @@ -0,0 +1,196 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import struct +from typing import Optional + +from dubbo.compressor.compressor import Compressor, DeCompressor + +""" + gRPC Message Format Diagram + +----------------------+-------------------------+------------------+ + | HTTP Header | gRPC Header | Business Data | + +----------------------+-------------------------+------------------+ + | (variable length) | compressed-flag (1 byte)| data (variable) | + | | message length (4 byte) | | + +----------------------+-------------------------+------------------+ +""" + +HEADER: str = "HEADER" +PAYLOAD: str = "PAYLOAD" + +# About HEADER +HEADER_LENGTH: int = 5 +COMPRESSED_FLAG_MASK: int = 1 +RESERVED_MASK = 0xFE + + +class TriEncoder: + """ + This class is responsible for encoding the gRPC message format, which is composed of a header and payload. + + Args: + compressor (Optional[Compressor]): The compressor to use for compressing the payload. 
+ """ + + HEADER_LENGTH: int = 5 + COMPRESSED_FLAG_MASK: int = 1 + + def __init__(self, compressor: Optional[Compressor]): + self._compressor: Optional[Compressor] = compressor + + def encode(self, message: bytes, compressed: bool = False) -> bytes: + """ + Encode the message into the gRPC message format. + + Args: + message (bytes): The message to encode. + compressed (bool): Whether to compress the message. + Returns: + bytes: The encoded message in gRPC format. + """ + compressed_flag = COMPRESSED_FLAG_MASK if compressed else 0 + if compressed: + # Compress the payload + message = self._compressor.compress(message) + + message_length = len(message) + if message_length > 0xFFFFFFFF: + raise ValueError("Message too large to encode") + + # Create the header + header = struct.pack(">BI", compressed_flag, message_length) + + return header + message + + +class TriDecoder: + """ + This class is responsible for decoding the gRPC message format, which is composed of a header and payload. + + Args: + listener (TriDecoder.Listener): The listener to deliver the decoded payload to. + decompressor (Optional[DeCompressor]): The decompressor to use for decompressing the payload. + """ + + def __init__( + self, + listener: "TriDecoder.Listener", + decompressor: Optional[DeCompressor], + ): + # store data for decoding + self._accumulate = bytearray() + self._listener = listener + self._decompressor = decompressor + + self._state = HEADER + self._required_length = HEADER_LENGTH + + # decode state, if True, the decoder is currently processing a message + self._decoding = False + + # whether the message is compressed + self._compressed = False + + self._closing = False + self._closed = False + + def decode(self, data: bytes): + """ + Process the incoming bytes, decoding the gRPC message and delivering the payload to the listener. + """ + self._accumulate.extend(data) + self._do_decode() + + def close(self): + """ + Close the decoder and listener. + """ + self._closing = True + self._do_decode() + + def _do_decode(self): + """ + Deliver the accumulated bytes to the listener, processing the header and payload as necessary. + """ + if self._decoding: + return + + self._decoding = True + try: + while self._has_enough_bytes(): + if self._state == HEADER: + self._process_header() + elif self._state == PAYLOAD: + self._process_payload() + if self._closing: + if not self._closed: + self._closed = True + self._accumulate = None + self._listener.close() + finally: + self._decoding = False + + def _has_enough_bytes(self): + """ + Check if the accumulated bytes are enough to process the header or payload + """ + return len(self._accumulate) >= self._required_length + + def _process_header(self): + """ + Processes the GRPC compression header which is composed of the compression flag and the outer frame length. + """ + header_bytes = self._accumulate[: self._required_length] + self._accumulate = self._accumulate[self._required_length :] + # Parse the header + compressed_flag = header_bytes[0] + if (compressed_flag & RESERVED_MASK) != 0: + raise ValueError("gRPC frame header malformed: reserved bits not zero") + + self._compressed = bool(compressed_flag & COMPRESSED_FLAG_MASK) + self._required_length = int.from_bytes(header_bytes[1:], byteorder="big") + # Continue to process the payload + self._state = PAYLOAD + + def _process_payload(self): + """ + Processes the GRPC message body, which depending on frame header flags may be compressed. 
+ """ + payload_bytes = self._accumulate[: self._required_length] + self._accumulate = self._accumulate[self._required_length :] + + if self._compressed: + # Decompress the payload + payload_bytes = self._decompressor.decompress(payload_bytes) + + self._listener.on_message(bytes(payload_bytes)) + + # Done with this frame, begin processing the next header. + self._required_length = HEADER_LENGTH + self._state = HEADER + + class Listener: + def on_message(self, message: bytes): + """ + Called when a message is received. + """ + raise NotImplementedError("Listener.on_message() not implemented") + + def close(self): + """ + Called when the listener is closed. + """ + raise NotImplementedError("Listener.close() not implemented") diff --git a/dubbo/protocol/triple/tri_decoder.py b/dubbo/protocol/triple/tri_decoder.py deleted file mode 100644 index 3defcbd..0000000 --- a/dubbo/protocol/triple/tri_decoder.py +++ /dev/null @@ -1,152 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import enum - -from dubbo.compressor.compressor import DeCompressor - - -class GrpcDecodeState(enum.Enum): - """ - gRPC Decode State - """ - - HEADER = 0 - PAYLOAD = 1 - - -class TriDecoder: - """ - This class is responsible for decoding the gRPC message format, which is composed of a header and payload. - gRPC Message Format Diagram - - +----------------------+-------------------------+------------------+ - | HTTP Header | gRPC Header | Business Data | - +----------------------+-------------------------+------------------+ - | (variable length) | type (1 byte) | data (variable) | - | | compressed-flag (1 byte)| | - | | message length (4 byte) | | - +----------------------+-------------------------+------------------+ - - Args: - decompressor (DeCompressor): The decompressor to use for decompressing the payload. - listener (TriDecoder.Listener): The listener to deliver the decoded payload to. - - """ - - HEADER_LENGTH: int = 5 - COMPRESSED_FLAG_MASK: int = 1 - RESERVED_MASK: int = 0xFE - - def __init__(self, decompressor: DeCompressor, listener: "TriDecoder.Listener"): - self.accumulate = bytearray() - self._decompressor = decompressor - self._listener = listener - self.state = GrpcDecodeState.HEADER - self.required_length = self.HEADER_LENGTH - self.compressed = False - self.in_delivery = False - self.closing = False - self.closed = False - - def deframe(self, data: bytes): - """ - Process the incoming bytes, deframing the gRPC message and delivering the payload to the listener. - """ - self.accumulate.extend(data) - self._deliver() - - def close(self): - """ - Close the decoder and listener. - """ - self.closing = True - self._deliver() - - def _deliver(self): - """ - Deliver the accumulated bytes to the listener, processing the header and payload as necessary. 
- """ - if self.in_delivery: - return - - self.in_delivery = True - try: - while self._has_enough_bytes(): - if self.state == GrpcDecodeState.HEADER: - self._process_header() - elif self.state == GrpcDecodeState.PAYLOAD: - self._process_payload() - if self.closing: - if not self.closed: - self.closed = True - self.accumulate = None - self._listener.close() - finally: - self.in_delivery = False - - def _has_enough_bytes(self): - """ - Check if the accumulated bytes are enough to process the header or payload - """ - return len(self.accumulate) >= self.required_length - - def _process_header(self): - """ - Processes the GRPC compression header which is composed of the compression flag and the outer frame length. - """ - header_bytes = self.accumulate[: self.required_length] - self.accumulate = self.accumulate[self.required_length :] - - type_byte = header_bytes[0] - - if type_byte & self.RESERVED_MASK: - raise ValueError("gRPC frame header malformed: reserved bits not zero") - - self.compressed = bool(type_byte & self.COMPRESSED_FLAG_MASK) - self.required_length = int.from_bytes(header_bytes[1:], byteorder="big") - - # Continue to process the payload - self.state = GrpcDecodeState.PAYLOAD - - def _process_payload(self): - """ - Processes the GRPC message body, which depending on frame header flags may be compressed. - """ - payload_bytes = self.accumulate[: self.required_length] - self.accumulate = self.accumulate[self.required_length :] - - if self.compressed: - # Decompress the payload - payload_bytes = self._decompressor.decompress(payload_bytes) - - self._listener.on_message(payload_bytes) - - # Done with this frame, begin processing the next header. - self.required_length = self.HEADER_LENGTH - self.state = GrpcDecodeState.HEADER - - class Listener: - def on_message(self, message: bytes): - """ - Called when a message is received. - """ - raise NotImplementedError("Listener.on_message() not implemented") - - def close(self): - """ - Called when the listener is closed. - """ - raise NotImplementedError("Listener.close() not implemented") diff --git a/dubbo/protocol/triple/tri_invoker.py b/dubbo/protocol/triple/tri_invoker.py index d2730a8..56f60a9 100644 --- a/dubbo/protocol/triple/tri_invoker.py +++ b/dubbo/protocol/triple/tri_invoker.py @@ -13,25 +13,121 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dubbo.common.url import URL -from dubbo.protocol.invocation import Invocation +from typing import Any, List, Tuple + +from dubbo.constants import common_constants +from dubbo.logger.logger_factory import loggerFactory +from dubbo.protocol.invocation import Invocation, RpcInvocation from dubbo.protocol.invoker import Invoker from dubbo.protocol.result import Result +from dubbo.protocol.triple.tri_client import TriClientCall, TriResult +from dubbo.remoting.aio.h2_stream_handler import StreamHandler +from dubbo.remoting.transporter import Client +from dubbo.url import URL + +logger = loggerFactory.get_logger(__name__) + + +class TriClientCallListener(TriClientCall.Listener): + + def __init__(self, result: TriResult): + self._result = result + + def on_message(self, message: Any) -> None: + # Set the message to the result + self._result.set_value(message) + + def on_complete(self) -> None: + # Set the end signal to the result + self._result.set_value(self._result.END_SIGNAL) + + +class TriInvoker(Invoker): + def __init__(self, url: URL, client: Client, stream_handler: StreamHandler): + self._url = url + self._client = client + self._stream_handler = stream_handler -class TripleInvoker(Invoker): + self._destroyed = False - def __init__(self, url: URL): - self.url = url + def invoke(self, invocation: RpcInvocation) -> Result: + call_type = invocation.get_attribute(common_constants.CALL_KEY) + result = TriResult(call_type) - def invoke(self, invocation: Invocation) -> Result: - pass + # TODO Return an exception result + if self.destroyed: + logger.warning("The invoker has been destroyed.") + raise Exception("The invoker has been destroyed.") + elif not self._client.connected: + pass + + # Create a new TriClientCall object + tri_client_call = TriClientCall( + TriClientCallListener(result), + url=self._url, + request_serializer=invocation.get_attribute(common_constants.SERIALIZATION), + response_deserializer=invocation.get_attribute( + common_constants.DESERIALIZATION + ), + ) + stream = self._stream_handler.create(tri_client_call) + tri_client_call.bind_stream(stream) + + if call_type in ( + common_constants.CALL_UNARY, + common_constants.CALL_SERVER_STREAM, + ): + self._invoke_unary(tri_client_call, invocation) + elif call_type in ( + common_constants.CALL_CLIENT_STREAM, + common_constants.CALL_BIDI_STREAM, + ): + self._invoke_stream(tri_client_call, invocation) + + return result + + def _invoke_unary(self, call: TriClientCall, invocation: Invocation) -> None: + call.send_headers(self._create_headers(invocation)) + call.send_message(invocation.get_argument(), last=True) + + def _invoke_stream(self, call: TriClientCall, invocation: Invocation) -> None: + call.send_headers(self._create_headers(invocation)) + next_message = None + for message in invocation.get_argument(): + if next_message is not None: + call.send_message(next_message, last=False) + next_message = message + call.send_message(next_message, last=True) + + def _create_headers(self, invocation: Invocation) -> List[Tuple[str, str]]: + + headers = [ + (":method", "POST"), + (":authority", self._url.location), + (":scheme", self._url.scheme), + ( + ":path", + f"/{invocation.get_service_name()}/{invocation.get_method_name()}", + ), + ("content-type", "application/grpc+proto"), + ("te", "trailers"), + ] + # TODO Add more headers information + return headers def get_url(self) -> URL: - return self.url + return self._url def is_available(self) -> bool: - pass + return self._client.connected + + @property + def destroyed(self) -> bool: + 
return self._destroyed def destroy(self) -> None: - pass + self._client.close() + self._client = None + self._stream_handler = None + self._url = None diff --git a/dubbo/protocol/triple/triple_protocol.py b/dubbo/protocol/triple/tri_listener.py similarity index 68% rename from dubbo/protocol/triple/triple_protocol.py rename to dubbo/protocol/triple/tri_listener.py index 445ffef..5f1ab3e 100644 --- a/dubbo/protocol/triple/triple_protocol.py +++ b/dubbo/protocol/triple/tri_listener.py @@ -13,16 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dubbo.common.url import URL -from dubbo.logger.logger_factory import loggerFactory -from dubbo.protocol.invoker import Invoker -from dubbo.protocol.protocol import Protocol +from typing import List, Tuple -logger = loggerFactory.get_logger(__name__) +from dubbo.remoting.aio.h2_stream import Stream -class TripleProtocol(Protocol): +class TriClientStreamListener(Stream.Listener): - def refer(self, url: URL) -> Invoker: + def on_headers(self, headers: List[Tuple[str, str]]) -> None: + pass + + def on_data(self, data: bytes) -> None: + pass + + def on_complete(self) -> None: + pass + def on_reset(self, err_code: int) -> None: pass diff --git a/dubbo/protocol/triple/tri_protocol.py b/dubbo/protocol/triple/tri_protocol.py new file mode 100644 index 0000000..1f9e6e6 --- /dev/null +++ b/dubbo/protocol/triple/tri_protocol.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from concurrent.futures import ThreadPoolExecutor + +from dubbo.constants import common_constants +from dubbo.extension import extensionLoader +from dubbo.logger.logger_factory import loggerFactory +from dubbo.protocol.invoker import Invoker +from dubbo.protocol.protocol import Protocol +from dubbo.protocol.triple.tri_invoker import TriInvoker +from dubbo.remoting.aio.h2_protocol import H2Protocol +from dubbo.remoting.aio.h2_stream_handler import ClientStreamHandler +from dubbo.remoting.transporter import Transporter +from dubbo.url import URL + +logger = loggerFactory.get_logger(__name__) + + +class TripleProtocol(Protocol): + + def __init__(self, url: URL): + self._url = url + self._transporter: Transporter = extensionLoader.get_extension( + Transporter, + self._url.get_parameter(common_constants.TRANSPORTER_KEY) or "aio", + )() + self._invokers = [] + + def refer(self, url: URL) -> Invoker: + """ + Refer a remote service. + Args: + url (URL): The URL of the remote service. 
+ """ + # TODO Simply create it here, then set up a more appropriate configuration that can be configured by the user + executor = ThreadPoolExecutor(thread_name_prefix="dubbo-tri-") + # Create a stream handler + stream_handler = ClientStreamHandler(executor) + url.add_attribute("protocol", H2Protocol) + url.add_attribute("stream_handler", stream_handler) + # Create a client + client = self._transporter.connect(url) + invoker = TriInvoker(url, client, stream_handler) + self._invokers.append(invoker) + return invoker diff --git a/dubbo/protocol/triple/tri_rpc_status.py b/dubbo/protocol/triple/tri_rpc_status.py new file mode 100644 index 0000000..98af7a5 --- /dev/null +++ b/dubbo/protocol/triple/tri_rpc_status.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum + + +class TriRpcCode(enum.Enum): + """ + See https://github.com/grpc/grpc/blob/master/doc/statuscodes.md + """ + + # Not an error; returned on success. + OK = 0 + # The operation was cancelled, typically by the caller. + CANCELLED = 1 + # Unknown error. + UNKNOWN = 2 + # The client specified an invalid argument. + INVALID_ARGUMENT = 3 + # The deadline expired before the operation could complete. + DEADLINE_EXCEEDED = 4 + # Some requested entity (e.g., file or directory) was not found + NOT_FOUND = 5 + # The entity that a client attempted to create (e.g., file or directory) already exists. + ALREADY_EXISTS = 6 + # The caller does not have permission to execute the specified operation. + PERMISSION_DENIED = 7 + # Some resource has been exhausted, perhaps a per-user quota, or perhaps the entire file system is out of space. + RESOURCE_EXHAUSTED = 8 + # The operation was rejected because the system is not in a state required for the operation's execution. + FAILED_PRECONDITION = 9 + # The operation was aborted, typically due to a concurrency issue such as a sequencer check failure or transaction abort. + ABORTED = 10 + # The operation was attempted past the valid range. + OUT_OF_RANGE = 11 + # The operation is not implemented or is not supported/enabled in this service. + UNIMPLEMENTED = 12 + # Internal errors. + INTERNAL = 13 + # The service is currently unavailable. + UNAVAILABLE = 14 + # Unrecoverable data loss or corruption. + DATA_LOSS = 15 + # The request does not have valid authentication credentials for the operation. + UNAUTHENTICATED = 16 diff --git a/dubbo/remoting/aio/aio_transporter.py b/dubbo/remoting/aio/aio_transporter.py index d684434..1e6e128 100644 --- a/dubbo/remoting/aio/aio_transporter.py +++ b/dubbo/remoting/aio/aio_transporter.py @@ -13,37 +13,149 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio +import threading +import uuid +from typing import Optional, Tuple -from dubbo.common.url import URL +from dubbo.constants import common_constants from dubbo.logger.logger_factory import loggerFactory -from dubbo.remoting.transporter import RemotingClient, RemotingServer, Transporter +from dubbo.remoting.aio import loop +from dubbo.remoting.transporter import Client, Server, Transporter +from dubbo.url import URL logger = loggerFactory.get_logger(__name__) -class AioTransporter(Transporter): +class AioClient(Client): """ - Asyncio transporter. + Asyncio client. + Args: + url(URL): The configuration of the client. """ - def bind(self, url: URL) -> RemotingServer: - pass + def __init__(self, url: URL): + super().__init__(url) + + # Set the side of the transporter to client. + self._url.add_parameter( + common_constants.TRANSPORTER_SIDE_KEY, + common_constants.TRANSPORTER_SIDE_CLIENT, + ) + + # Set connection closed function + def _connection_lost(exc: Optional[Exception]) -> None: + if exc: + logger.error("Connection lost", exc) + self._connected = False + + self._url.add_attribute( + common_constants.TRANSPORTER_ON_CONN_CLOSE_KEY, _connection_lost + ) + + self._thread: Optional[threading.Thread] = None + self._loop: Optional[asyncio.AbstractEventLoop] = None + + self._transport: Optional[asyncio.Transport] = None + self._protocol: Optional[asyncio.Protocol] = None + + self._closing = False + + # Open and connect the client + self.open() + self.connect() + + def open(self) -> None: + """ + Create a thread and run asyncio loop in it. + """ + self._loop, self._thread = loop.start_loop_in_thread( + f"dubbo-aio-client-{uuid.uuid4()}" + ) + self._opened = True + + def _create_protocol(self) -> asyncio.Protocol: + """ + Create the protocol. + """ + + return self._url.attributes["protocol"](self._url) - def connect(self, url: URL) -> RemotingClient: - pass + def connect(self) -> None: + """ + Connect to the server. + """ + if not self._opened: + raise RuntimeError("The client is not opened yet.") + elif self._closed: + raise RuntimeError("The client is closed.") + async def _inner_connect() -> Tuple[asyncio.Transport, asyncio.Protocol]: + running_loop = asyncio.get_running_loop() -class AioClient(RemotingClient): + transport, protocol = await running_loop.create_connection( + lambda: self._url.get_attribute("protocol")(self._url), + self._url.host, + self._url.port, + ) + return transport, protocol + + future = asyncio.run_coroutine_threadsafe(_inner_connect(), self._loop) + + try: + self._transport, self._protocol = future.result() + self._connected = True + logger.info( + f"Connected to the server: ip={self._url.host}, port={self._url.port}" + ) + except Exception as e: + logger.error(f"Failed to connect to the server: {e}") + raise e + + def close(self) -> None: + """ + Close the client. just stop the transport. + """ + if not self._opened: + raise RuntimeError("The client is not opened yet.") + if self._closing or self._closed: + return + + self._closing = True + + try: + # Close the transport + self._transport.close() + self._connected = False + # Stop the loop + loop.stop_loop_in_thread(self._loop, self._thread) + self._closed = True + finally: + self._closing = False + + +class AioServer(Server): """ - Asyncio client. + Asyncio server. """ - pass + def __init__(self, url: URL): + self._url = url + # Set the side of the transporter to server. 
+ self._url.add_parameter( + common_constants.TRANSPORTER_SIDE_KEY, + common_constants.TRANSPORTER_SIDE_SERVER, + ) + # TODO implement the server -class AioServer(RemotingServer): +class AioTransporter(Transporter): """ - Asyncio server. + Asyncio transporter. """ - pass + def connect(self, url: URL) -> Client: + return AioClient(url) + + def bind(self, url: URL) -> Server: + return AioServer(url) diff --git a/dubbo/remoting/aio/constants.py b/dubbo/remoting/aio/constants.py deleted file mode 100644 index cbcc52c..0000000 --- a/dubbo/remoting/aio/constants.py +++ /dev/null @@ -1,18 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Used to indicate the end of the data. -END_DATA_SENTINEL = object() diff --git a/dubbo/remoting/aio/h2_frame.py b/dubbo/remoting/aio/h2_frame.py index af3f0d5..0cdc022 100644 --- a/dubbo/remoting/aio/h2_frame.py +++ b/dubbo/remoting/aio/h2_frame.py @@ -18,15 +18,8 @@ import time from typing import Any, Dict, Optional -from h2.events import ( - DataReceived, - Event, - RequestReceived, - ResponseReceived, - StreamReset, - TrailersReceived, - WindowUpdated, -) +from h2.events import (DataReceived, Event, RequestReceived, ResponseReceived, + StreamReset, TrailersReceived, WindowUpdated) class H2FrameType(enum.Enum): diff --git a/dubbo/remoting/aio/h2_protocol.py b/dubbo/remoting/aio/h2_protocol.py index 1707f7c..dd1c73f 100644 --- a/dubbo/remoting/aio/h2_protocol.py +++ b/dubbo/remoting/aio/h2_protocol.py @@ -14,14 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +import threading +from concurrent.futures import Future as ThreadingFuture from typing import Dict, Optional, Tuple from h2.config import H2Configuration from h2.connection import H2Connection +from dubbo.constants import common_constants from dubbo.logger.logger_factory import loggerFactory from dubbo.remoting.aio.h2_frame import H2Frame, H2FrameType, H2FrameUtils from dubbo.remoting.aio.h2_stream_handler import StreamHandler +from dubbo.url import URL logger = loggerFactory.get_logger(__name__) @@ -198,13 +202,20 @@ class H2Protocol(asyncio.Protocol): It handles connection state, stream mapping, and data flow control. Args: - h2_config (H2Configuration): The configuration for the H2 connection. - stream_handler (StreamHandler): The handler for managing streams. - + url (URL): The URL object that contains the connection parameters. 
""" - def __init__(self, h2_config: H2Configuration, stream_handler: StreamHandler): + def __init__(self, url: URL): + self.url = url # Create the H2 state machine + client_side = ( + self.url.parameters.get( + common_constants.TRANSPORTER_SIDE_KEY, + common_constants.TRANSPORTER_SIDE_CLIENT, + ) + == common_constants.TRANSPORTER_SIDE_CLIENT + ) + h2_config = H2Configuration(client_side=client_side, header_encoding="utf-8") self.conn: H2Connection = H2Connection(config=h2_config) # the backing transport. @@ -214,7 +225,7 @@ def __init__(self, h2_config: H2Configuration, stream_handler: StreamHandler): self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() # A mapping of stream ID to stream object. - self._stream_handler: StreamHandler = stream_handler + self._stream_handler: StreamHandler = self.url.attributes["stream_handler"] self._data_follow_control: Optional[DataFlowControl] = None @@ -246,6 +257,19 @@ def connection_lost(self, exc) -> None: self._stream_handler.destroy() self._data_follow_control.cancel() + # Handle the connection close event + if on_conn_lost := self.url.attributes.get( + common_constants.TRANSPORTER_ON_CONN_CLOSE_KEY + ): + if isinstance(on_conn_lost, (asyncio.Event, threading.Event)): + on_conn_lost.set() + elif isinstance(on_conn_lost, (asyncio.Future, ThreadingFuture)): + on_conn_lost.set_result(exc) + elif callable(on_conn_lost): + on_conn_lost(exc) + else: + logger.error("Unable to handle the connection close event") + def send_headers_frame(self, headers_frame: H2Frame) -> asyncio.Event: """ Send headers to the remote peer. (thread-safe) @@ -258,9 +282,9 @@ def send_headers_frame(self, headers_frame: H2Frame) -> asyncio.Event: """ headers_event = asyncio.Event() - def _inner_send_headers_frame(headers_frame: H2Frame, event: asyncio.Event): + def _inner_send_headers_frame(_headers_frame: H2Frame, event: asyncio.Event): self.conn.send_headers( - headers_frame.stream_id, headers_frame.data, headers_frame.end_stream + _headers_frame.stream_id, _headers_frame.data, _headers_frame.end_stream ) self.transport.write(self.conn.data_to_send()) # Set the event to indicate that the headers frame has been sent. @@ -316,8 +340,8 @@ def data_received(self, data: bytes) -> None: frame = H2FrameUtils.create_frame_by_event(event) if not frame: # If frame is None, there are two possible cases: - # 1. Events that are handled automatically by the H2 library. -> We just need to send it. - # e.g. RemoteSettingsChanged, PingReceived + # 1. Events that are handled automatically by the H2 library (e.g. RemoteSettingsChanged, PingReceived). + # -> We just need to send it. # 2. Events that are not implemented or do not require attention. -> We'll ignore it for now. 
pass else: @@ -326,6 +350,9 @@ def data_received(self, data: bytes) -> None: # Update the flow control window self._data_follow_control.release(frame) else: + if frame.frame_type == H2FrameType.RST_STREAM: + # Reset the stream + self._data_follow_control.reset(frame) # Handle the frame self._stream_handler.handle_frame(frame) diff --git a/dubbo/remoting/aio/h2_stream.py b/dubbo/remoting/aio/h2_stream.py index 5880fee..05deadd 100644 --- a/dubbo/remoting/aio/h2_stream.py +++ b/dubbo/remoting/aio/h2_stream.py @@ -17,12 +17,8 @@ from typing import List, Optional, Tuple from dubbo.logger.logger_factory import loggerFactory -from dubbo.remoting.aio.h2_frame import ( - DATA_COMPLETED_FRAME, - H2Frame, - H2FrameType, - H2FrameUtils, -) +from dubbo.remoting.aio.h2_frame import (DATA_COMPLETED_FRAME, H2Frame, + H2FrameType, H2FrameUtils) logger = loggerFactory.get_logger(__name__) @@ -220,20 +216,29 @@ class Stream: Args: stream_id (int): The stream identifier. - protocol (H2Protocol): The protocol instance used to send frames. + listener (Stream.Listener): The listener for the stream to handle the received frames. loop (asyncio.AbstractEventLoop): The asyncio event loop. + protocol (H2Protocol): The protocol instance used to send frames. """ - def __init__(self, stream_id: int, protocol, loop: asyncio.AbstractEventLoop): + def __init__( + self, + stream_id: int, + listener: "Stream.Listener", + loop: asyncio.AbstractEventLoop, + protocol, + ): # import here to avoid circular import from dubbo.remoting.aio.h2_protocol import H2Protocol - # The protocol. - self._protocol: H2Protocol = protocol - # The stream ID. self._stream_id: int = stream_id + # The listener for the stream to handle the received frames. + self._listener: "Stream.Listener" = listener + + # The protocol. + self._protocol: H2Protocol = protocol # The asyncio event loop. self._loop = loop @@ -268,17 +273,10 @@ def _inner_send_headers(_headers: List[Tuple[str, str]], _end_stream: bool): self._stream_id, _headers, _end_stream ) self._stream_frame_control.put_headers(headers_frame) - if end_stream: - # The data is completed. - self._stream_frame_control.put_data(DATA_COMPLETED_FRAME) self._loop.call_soon_threadsafe(_inner_send_headers, headers, end_stream) - - def close(self) -> None: - """ - Close the stream by cancelling the frame sender loop. - """ - self._stream_frame_control.cancel() + # Try to close the stream + self.try_close() def send_data(self, data: bytes, end_stream: bool = False) -> None: """ @@ -289,7 +287,6 @@ def send_data(self, data: bytes, end_stream: bool = False) -> None: end_stream (bool): Whether to end the stream after sending this frame. """ if self._send_completed: - logger.info("Send completed.") return else: self._send_completed = end_stream @@ -301,6 +298,18 @@ def _inner_send_data(_data: bytes, _end_stream: bool): self._stream_frame_control.put_data(data_frame) self._loop.call_soon_threadsafe(_inner_send_data, data, end_stream) + # Try to close the stream + self.try_close() + + def send_data_completed(self) -> None: + """ + Indicates that the data frame has been fully sent, but other frames (such as trailers) may still need to be sent. + """ + + def _inner_send_data_completed(): + self._stream_frame_control.put_data(DATA_COMPLETED_FRAME) + + self._loop.call_soon_threadsafe(_inner_send_data_completed) def send_reset(self, error_code: int) -> None: """ @@ -322,6 +331,9 @@ def _inner_send_reset(_error_code: int): self._loop.call_soon_threadsafe(_inner_send_reset, error_code) + # Close the stream immediately. 
+ self.close() + def receive_headers(self, headers: List[Tuple[str, str]]) -> None: """ Called when a headers frame is received. @@ -329,7 +341,7 @@ def receive_headers(self, headers: List[Tuple[str, str]]) -> None: Args: headers (List[Tuple[str, str]]): The headers received. """ - raise NotImplementedError("receive_headers() is not implemented") + self._listener.on_headers(headers) def receive_data(self, data: bytes) -> None: """ @@ -338,29 +350,74 @@ def receive_data(self, data: bytes) -> None: Args: data (bytes): The data received. """ - raise NotImplementedError("receive_data() is not implemented") + self._listener.on_data(data) def receive_complete(self) -> None: """ Called when the stream is completed. """ self._receive_completed = True + # notify the listener + self._listener.on_complete() + # Try to close the stream + self.try_close() - def cancel_by_remote(self, err_code: int) -> None: + def receive_reset(self, err_code: int) -> None: """ Called when the stream is cancelled by the remote peer. Args: err_code (int): The error code indicating the reason for cancellation. """ - raise NotImplementedError("cancel_by_remote() is not implemented") + self._listener.on_reset(err_code) + def try_close(self) -> None: + """ + Try to close the stream. + """ + if self._send_completed and self._receive_completed: + self.close() -class ClientStream(Stream): - # TODO implement the ClientStream - pass + def close(self) -> None: + """ + Close the stream by cancelling the frame sender loop. + """ + self._stream_frame_control.cancel() + class Listener: + """ + The listener for the stream to handle the received frames. + """ -class ServerStream(Stream): - # TODO implement the ServerStream - pass + def on_headers(self, headers: List[Tuple[str, str]]) -> None: + """ + Called when a headers frame is received. + + Args: + headers (List[Tuple[str, str]]): The headers received. + """ + raise NotImplementedError("on_headers() is not implemented") + + def on_data(self, data: bytes) -> None: + """ + Called when a data frame is received. + + Args: + data (bytes): The data received. + """ + raise NotImplementedError("on_data() is not implemented") + + def on_complete(self) -> None: + """ + Called when the stream is completed. + """ + raise NotImplementedError("on_complete() is not implemented") + + def on_reset(self, err_code: int) -> None: + """ + Called when the stream is cancelled by the remote peer. + + Args: + err_code (int): The error code indicating the reason for cancellation. + """ + raise NotImplementedError("on_reset() is not implemented") diff --git a/dubbo/remoting/aio/h2_stream_handler.py b/dubbo/remoting/aio/h2_stream_handler.py index 257bcfc..9142eb9 100644 --- a/dubbo/remoting/aio/h2_stream_handler.py +++ b/dubbo/remoting/aio/h2_stream_handler.py @@ -16,11 +16,11 @@ import asyncio from concurrent.futures import Future as ThreadingFuture from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Optional +from typing import Dict, Optional, Tuple from dubbo.logger.logger_factory import loggerFactory from dubbo.remoting.aio.h2_frame import H2Frame, H2FrameType -from dubbo.remoting.aio.h2_stream import ClientStream, ServerStream, Stream +from dubbo.remoting.aio.h2_stream import Stream logger = loggerFactory.get_logger(__name__) @@ -42,7 +42,7 @@ def __init__( self._protocol: Optional[H2Protocol] = None # The event loop to run the asynchronous function. 
- self._loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_event_loop() + self._loop: Optional[asyncio.AbstractEventLoop] = None # The streams managed by the handler self._streams: Dict[int, Stream] = {} @@ -59,6 +59,7 @@ def init(self, loop: asyncio.AbstractEventLoop, protocol) -> None: """ self._loop = loop self._protocol = protocol + self._streams.clear() def handle_frame(self, frame: H2Frame) -> None: """ @@ -87,18 +88,20 @@ def _handle_in_executor(self, frame: H2Frame) -> None: elif frame_type == H2FrameType.DATA: stream.receive_data(frame.data) elif frame_type == H2FrameType.RST_STREAM: - stream.cancel_by_remote(frame.data) + stream.receive_reset(frame.data) else: logger.debug(f"Unhandled frame: {frame_type}") if frame.end_stream: stream.receive_complete() - def create(self) -> Stream: + def create(self, listener: Stream.Listener) -> Stream: """ Create a new stream. -> Client + Args: + listener: The listener to the stream. Returns: - Stream: The stream object. + Stream: The new stream. """ raise NotImplementedError("create() is not implemented") @@ -129,33 +132,44 @@ def destroy(self) -> None: class ClientStreamHandler(StreamHandler): - def create(self) -> Stream: + def create(self, listener: Stream.Listener) -> Stream: """ Create a new stream. -> Client + Args: + listener: The listener to the stream. + Returns: + Stream: The new stream. """ # Create a new client stream future = ThreadingFuture() - def _inner_create(future: ThreadingFuture): + def _inner_create(_future: ThreadingFuture): new_stream_id = self._protocol.conn.get_next_available_stream_id() - new_stream = ClientStream(new_stream_id, self._protocol, self._loop) + new_stream = Stream(new_stream_id, listener, self._loop, self._protocol) self._streams[new_stream_id] = new_stream - future.set_result(new_stream) + _future.set_result(new_stream) self._loop.call_soon_threadsafe(_inner_create, future) + # Return the stream and the listener return future.result() - # TODO implement ClientStreamHandler... - class ServerStreamHandler(StreamHandler): - def register(self, stream_id: int) -> None: + def register(self, stream_id: int) -> Tuple[Stream, Stream.Listener]: """ Register the stream to the handler -> Server + Args: + stream_id: The stream ID. + Returns: + (Stream, Stream.Listener): A tuple containing the stream and the listener. """ - new_stream = ServerStream(stream_id, self._protocol, self._loop) + # TODO Create a new listener + new_listener = Stream.Listener() + new_stream = Stream(stream_id, new_listener, self._loop, self._protocol) self._streams[stream_id] = new_stream + # Return the stream and the listener + return new_stream, new_listener def handle_frame(self, frame: H2Frame) -> None: # Register the stream if it is a HEADERS frame and the stream is not registered. @@ -165,5 +179,3 @@ def handle_frame(self, frame: H2Frame) -> None: ): self.register(frame.stream_id) super().handle_frame(frame) - - # TODO implement ServerStreamHandler... diff --git a/dubbo/remoting/aio/loop.py b/dubbo/remoting/aio/loop.py new file mode 100644 index 0000000..503432e --- /dev/null +++ b/dubbo/remoting/aio/loop.py @@ -0,0 +1,150 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import threading +from typing import Optional, Tuple + +from dubbo.logger.logger_factory import loggerFactory + +logger = loggerFactory.get_logger(__name__) + + +def start_loop(running_loop: asyncio.AbstractEventLoop) -> None: + """ + Start the running_loop. + Args: + running_loop: The running_loop to start. + """ + asyncio.set_event_loop(running_loop) + running_loop.run_forever() + + +async def _stop_loop( + running_loop: Optional[asyncio.AbstractEventLoop] = None, + signal: Optional[threading.Event] = None, +) -> None: + """ + Real function to stop the running_loop. + Args: + running_loop: The running_loop to stop. If None, the current running_loop will be stopped. + signal: The future to set the result. + """ + running_loop = running_loop or asyncio.get_running_loop() + # Cancel all tasks + tasks = [ + task for task in asyncio.all_tasks(running_loop) if task is not asyncio.current_task() + ] + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + # Stop the event running_loop + running_loop.stop() + if signal: + # Set the result of the future + signal.set() + + +def stop_loop(running_loop: Optional[asyncio.AbstractEventLoop] = None, wait: bool = False): + """ + Stop the running_loop. It will cancel all tasks and stop the running_loop.(thread-safe) + Args: + running_loop: The running_loop to stop. If None, the current running_loop will be stopped. + wait: Whether to wait for the running_loop to stop. + """ + running_loop = running_loop or asyncio.get_running_loop() + # Create a future to wait for the running_loop to stop + signal = threading.Event() + # Call the asynchronous function to stop the running_loop + asyncio.run_coroutine_threadsafe(_stop_loop(signal=signal), running_loop) + if wait: + # Wait for the running_loop to stop + signal.wait() + + +def start_loop_in_thread( + thread_name: str, running_loop: Optional[asyncio.AbstractEventLoop] = None +) -> Tuple[asyncio.AbstractEventLoop, threading.Thread]: + """ + start the asyncio event running_loop in a separate thread. + + Args: + thread_name: The name of the thread to run the event running_loop in. + running_loop: The event running_loop to run in the thread. If None, a new event running_loop will be created. + + Returns: + A tuple containing the new event running_loop and the thread it is running in. + """ + new_loop = running_loop or asyncio.new_event_loop() + # Start the running_loop in a new thread + thread = threading.Thread( + target=start_loop, args=(new_loop,), name=thread_name, daemon=True + ) + # Start the thread + thread.start() + return new_loop, thread + + +def stop_loop_in_thread( + running_loop: asyncio.AbstractEventLoop, thread: threading.Thread, wait: bool = False +) -> None: + """ + Stop the event running_loop running in a separate thread and close the thread. + + Args: + running_loop: The event running_loop to stop. + thread: The thread running the event running_loop. + wait: Whether to wait for all tasks to be cancelled and the running_loop to stop. 
+ """ + stop_loop(running_loop, wait=wait) + # Wait for the thread to join + if wait: + print("等待线程结束") + thread.join() + + +def _try_use_uvloop() -> None: + """ + Use uvloop instead of the default asyncio running_loop. + """ + import asyncio + import os + + # Check if the operating system. + if os.name == "nt": + # Windows is not supported. + logger.warning( + "Unable to use uvloop, because it is not supported on your operating system." + ) + return + + # Try import uvloop. + try: + import uvloop + except ImportError: + # uvloop is not available. + logger.warning( + "Unable to use uvloop, because it is not installed. " + "You can install it by running `pip install uvloop`." + ) + return + + # Use uvloop instead of the default asyncio running_loop. + if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +# Call the function to try to use uvloop. +_try_use_uvloop() diff --git a/dubbo/remoting/transporter.py b/dubbo/remoting/transporter.py index 48c9f43..ff68bf4 100644 --- a/dubbo/remoting/transporter.py +++ b/dubbo/remoting/transporter.py @@ -13,28 +13,66 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dubbo.common.url import URL +from dubbo.url import URL -class RemotingServer: +class Client: - pass + def __init__(self, url: URL): + self._url = url + # flag to indicate whether the client is opened + self._opened = False + # flag to indicate whether the client is connected + self._connected = False + # flag to indicate whether the client is closed + self._closed = False + + @property + def opened(self): + return self._opened + + @property + def connected(self): + return self._connected + @property + def closed(self): + return self._closed -class RemotingClient: + def open(self): + """ + Open the client. + """ + raise NotImplementedError("open() is not implemented.") + def connect(self): + """ + Connect to the server. + """ + raise NotImplementedError("connect() is not implemented.") + + def close(self): + """ + Close the client. + """ + raise NotImplementedError("close() is not implemented.") + + +class Server: + # TODO define the interface of the server. pass class Transporter: - def bind(self, url: URL) -> RemotingServer: + + def connect(self, url: URL) -> Client: """ - Bind a server. + Connect to a server. """ - pass + raise NotImplementedError("connect() is not implemented.") - def connect(self, url: URL) -> RemotingClient: + def bind(self, url: URL) -> Server: """ - Connect to a server. + Bind a server. """ - pass + raise NotImplementedError("bind() is not implemented.") diff --git a/dubbo/serialization.py b/dubbo/serialization.py index 2049eb1..3d92f27 100644 --- a/dubbo/serialization.py +++ b/dubbo/serialization.py @@ -15,9 +15,9 @@ # limitations under the License. from typing import Any -from dubbo.common.constants import common_constants -from dubbo.common.url import URL +from dubbo.constants import common_constants from dubbo.logger.logger_factory import loggerFactory +from dubbo.url import URL logger = loggerFactory.get_logger(__name__) diff --git a/dubbo/common/url.py b/dubbo/url.py similarity index 92% rename from dubbo/common/url.py rename to dubbo/url.py index b4e65a0..0072164 100644 --- a/dubbo/common/url.py +++ b/dubbo/url.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import copy from typing import Any, Dict, Optional from urllib import parse @@ -21,7 +22,7 @@ class URL: """ URL - Uniform Resource Locator. Args: - protocol (str): The protocol of the URL. + scheme (str): The protocol of the URL. host (str): The host of the URL. port (int): The port number of the URL. username (str): The username for URL authentication. @@ -39,7 +40,7 @@ class URL: def __init__( self, - protocol: str, + scheme: str, host: str, port: int = 0, username: str = "", @@ -48,7 +49,7 @@ def __init__( parameters: Optional[Dict[str, str]] = None, attributes: Optional[Dict[str, Any]] = None, ): - self._protocol = protocol + self._scheme = scheme self._host = host self._port = port # location -> host:port @@ -60,24 +61,24 @@ def __init__( self._attributes = attributes or {} @property - def protocol(self) -> str: + def scheme(self) -> str: """ Gets the protocol of the URL. Returns: str: The protocol of the URL. """ - return self._protocol + return self._scheme - @protocol.setter - def protocol(self, protocol: str) -> None: + @scheme.setter + def scheme(self, scheme: str) -> None: """ Sets the protocol of the URL. Args: - protocol (str): The protocol to set. + scheme (str): The protocol to set. """ - self._protocol = protocol + self._scheme = scheme @property def location(self) -> str: @@ -272,7 +273,7 @@ def build_string(self, encode: bool = False) -> str: str: The generated URL string. """ # Set protocol - url = f"{self.protocol}://" if self.protocol else "" + url = f"{self.scheme}://" if self.scheme else "" # Set auth if self.username: url += f"{self.username}" @@ -293,6 +294,23 @@ def build_string(self, encode: bool = False) -> str: url = parse.quote(url) return url + def clone(self) -> "URL": + """ + Clones the URL object. Ignores the attributes. + + Returns: + URL: The cloned URL object. + """ + return URL( + self.scheme, + self.host, + self.port, + self.username, + self.password, + self.path, + copy.deepcopy(self.parameters), + ) + def __str__(self) -> str: """ Returns the URL string when the object is converted to a string. diff --git a/tests/common/tets_url.py b/tests/common/tets_url.py index 7252500..fa4c72d 100644 --- a/tests/common/tets_url.py +++ b/tests/common/tets_url.py @@ -15,7 +15,7 @@ # limitations under the License. 
import unittest -from dubbo.common.url import URL +from dubbo.url import URL class TestUrl(unittest.TestCase): @@ -24,7 +24,7 @@ def test_str_to_url(self): url_0 = URL.value_of( "http://www.facebook.com/friends?param1=value1¶m2=value2" ) - self.assertEqual("http", url_0.protocol) + self.assertEqual("http", url_0.scheme) self.assertEqual("www.facebook.com", url_0.host) self.assertEqual(0, url_0.port) self.assertEqual("friends", url_0.path) @@ -32,7 +32,7 @@ def test_str_to_url(self): self.assertEqual("value2", url_0.get_parameter("param2")) url_1 = URL.value_of("ftp://username:password@192.168.1.7:21/1/read.txt") - self.assertEqual("ftp", url_1.protocol) + self.assertEqual("ftp", url_1.scheme) self.assertEqual("username", url_1.username) self.assertEqual("password", url_1.password) self.assertEqual("192.168.1.7", url_1.host) @@ -41,14 +41,14 @@ def test_str_to_url(self): self.assertEqual("1/read.txt", url_1.path) url_2 = URL.value_of("file:///home/user1/router.js?type=script") - self.assertEqual("file", url_2.protocol) + self.assertEqual("file", url_2.scheme) self.assertEqual("home/user1/router.js", url_2.path) url_3 = URL.value_of( "http%3A//www.facebook.com/friends%3Fparam1%3Dvalue1%26param2%3Dvalue2", encoded=True, ) - self.assertEqual("http", url_3.protocol) + self.assertEqual("http", url_3.scheme) self.assertEqual("www.facebook.com", url_3.host) self.assertEqual(0, url_3.port) self.assertEqual("friends", url_3.path) @@ -57,7 +57,7 @@ def test_str_to_url(self): def test_url_to_str(self): url_0 = URL( - protocol="tri", + scheme="tri", host="127.0.0.1", port=12, username="username", @@ -70,7 +70,7 @@ def test_url_to_str(self): ) url_1 = URL( - protocol="tri", + scheme="tri", host="127.0.0.1", port=12, path="path", @@ -78,5 +78,5 @@ def test_url_to_str(self): ) self.assertEqual("tri://127.0.0.1:12/path?type=a", url_1.build_string()) - url_2 = URL(protocol="tri", host="127.0.0.1", port=12, parameters={"type": "a"}) + url_2 = URL(scheme="tri", host="127.0.0.1", port=12, parameters={"type": "a"}) self.assertEqual("tri://127.0.0.1:12/?type=a", url_2.build_string()) diff --git a/tests/logger/test_logger_factory.py b/tests/logger/test_logger_factory.py index fa3016a..c3e6fd1 100644 --- a/tests/logger/test_logger_factory.py +++ b/tests/logger/test_logger_factory.py @@ -15,8 +15,8 @@ # limitations under the License. import unittest -from dubbo.common.constants import logger_constants as logger_constants -from dubbo.common.constants.logger_constants import Level +from dubbo.constants import logger_constants as logger_constants +from dubbo.constants.logger_constants import Level from dubbo.config import LoggerConfig from dubbo.logger.logger_factory import loggerFactory from dubbo.logger.logging.logger_adapter import LoggingLoggerAdapter diff --git a/tests/logger/test_logging_logger.py b/tests/logger/test_logging_logger.py index c95a9ab..9915dc0 100644 --- a/tests/logger/test_logging_logger.py +++ b/tests/logger/test_logging_logger.py @@ -15,7 +15,7 @@ # limitations under the License. 
import unittest -from dubbo.common.constants.logger_constants import Level +from dubbo.constants.logger_constants import Level from dubbo.config import LoggerConfig from dubbo.logger.logging.logger_adapter import LoggingLoggerAdapter From 7355cd83a188a891e6fb6e75594f995a5c181595 Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 14 Jul 2024 21:35:07 +0800 Subject: [PATCH 28/32] feat: Complete the basic functions of the client --- dubbo/_dubbo.py | 3 +- dubbo/callable.py | 37 +- dubbo/client/client.py | 16 +- .../{compressor.py => compression.py} | 8 +- dubbo/compressor/gzip_compression.py | 44 ++ dubbo/config/method_config.py | 67 --- dubbo/config/reference_config.py | 56 +-- dubbo/constants/common_constants.py | 8 +- dubbo/extension/__init__.py | 3 +- dubbo/extension/registry.py | 10 + dubbo/logger/logging/logger_adapter.py | 39 +- dubbo/protocol/result.py | 21 +- .../protocol/triple/client}/__init__.py | 0 dubbo/protocol/triple/client/calls.py | 156 +++++++ .../protocol/triple/client/stream_listener.py | 108 +++++ dubbo/protocol/triple/tri_client.py | 196 -------- dubbo/protocol/triple/tri_codec.py | 37 +- dubbo/protocol/triple/tri_constants.py | 44 ++ dubbo/protocol/triple/tri_invoker.py | 105 +++-- dubbo/protocol/triple/tri_protocol.py | 17 +- dubbo/protocol/triple/tri_results.py | 82 ++++ .../{tri_rpc_status.py => tri_status.py} | 53 +++ dubbo/remoting/aio/aio_transporter.py | 129 +++--- dubbo/remoting/aio/event_loop.py | 173 +++++++ .../remoting/aio/exceptions.py | 37 +- dubbo/remoting/aio/h2_frame.py | 240 ---------- dubbo/remoting/aio/h2_protocol.py | 368 --------------- dubbo/remoting/aio/h2_stream.py | 423 ------------------ dubbo/remoting/aio/h2_stream_handler.py | 181 -------- .../aio/http2/__init__.py} | 18 - dubbo/remoting/aio/http2/controllers.py | 348 ++++++++++++++ dubbo/remoting/aio/http2/frames.py | 134 ++++++ dubbo/remoting/aio/http2/headers.py | 195 ++++++++ dubbo/remoting/aio/http2/protocol.py | 213 +++++++++ dubbo/remoting/aio/http2/registries.py | 289 ++++++++++++ dubbo/remoting/aio/http2/stream.py | 278 ++++++++++++ dubbo/remoting/aio/http2/stream_handler.py | 169 +++++++ dubbo/remoting/aio/http2/utils.py | 76 ++++ dubbo/remoting/aio/loop.py | 150 ------- dubbo/remoting/transporter.py | 34 +- dubbo/serialization.py | 118 ++--- dubbo/url.py | 84 ++-- tests/common/tets_url.py | 4 +- 43 files changed, 2751 insertions(+), 2020 deletions(-) rename dubbo/compressor/{compressor.py => compression.py} (95%) create mode 100644 dubbo/compressor/gzip_compression.py delete mode 100644 dubbo/config/method_config.py rename {tests/loop => dubbo/protocol/triple/client}/__init__.py (100%) create mode 100644 dubbo/protocol/triple/client/calls.py create mode 100644 dubbo/protocol/triple/client/stream_listener.py delete mode 100644 dubbo/protocol/triple/tri_client.py create mode 100644 dubbo/protocol/triple/tri_constants.py create mode 100644 dubbo/protocol/triple/tri_results.py rename dubbo/protocol/triple/{tri_rpc_status.py => tri_status.py} (71%) create mode 100644 dubbo/remoting/aio/event_loop.py rename tests/loop/test_loop_manger.py => dubbo/remoting/aio/exceptions.py (58%) delete mode 100644 dubbo/remoting/aio/h2_frame.py delete mode 100644 dubbo/remoting/aio/h2_protocol.py delete mode 100644 dubbo/remoting/aio/h2_stream.py delete mode 100644 dubbo/remoting/aio/h2_stream_handler.py rename dubbo/{protocol/triple/tri_listener.py => remoting/aio/http2/__init__.py} (67%) create mode 100644 dubbo/remoting/aio/http2/controllers.py create mode 100644 dubbo/remoting/aio/http2/frames.py 
create mode 100644 dubbo/remoting/aio/http2/headers.py create mode 100644 dubbo/remoting/aio/http2/protocol.py create mode 100644 dubbo/remoting/aio/http2/registries.py create mode 100644 dubbo/remoting/aio/http2/stream.py create mode 100644 dubbo/remoting/aio/http2/stream_handler.py create mode 100644 dubbo/remoting/aio/http2/utils.py delete mode 100644 dubbo/remoting/aio/loop.py diff --git a/dubbo/_dubbo.py b/dubbo/_dubbo.py index 05a096f..fece509 100644 --- a/dubbo/_dubbo.py +++ b/dubbo/_dubbo.py @@ -16,8 +16,7 @@ import threading from typing import Dict, List -from dubbo.config import (ApplicationConfig, ConsumerConfig, LoggerConfig, - ProtocolConfig) +from dubbo.config import ApplicationConfig, ConsumerConfig, LoggerConfig, ProtocolConfig from dubbo.logger.logger_factory import loggerFactory logger = loggerFactory.get_logger(__name__) diff --git a/dubbo/callable.py b/dubbo/callable.py index 749dddb..0481818 100644 --- a/dubbo/callable.py +++ b/dubbo/callable.py @@ -21,39 +21,34 @@ from dubbo.url import URL -class RpcCallable: +class AbstractRpcCallable: def __init__(self, invoker: Invoker, url: URL): self._invoker = invoker self._url = url - self._service_name = self._url.path or "" - self._method_name = self._url.get_parameter(common_constants.METHOD_KEY) or "" + self._service_name = self._url.path + self._method_name = self._url.get_parameter(common_constants.METHOD_KEY) self._call_type = self._url.get_parameter(common_constants.CALL_KEY) - self._request_serializer = ( - self._url.get_attribute(common_constants.SERIALIZATION) or None - ) - self._response_serializer = ( - self._url.get_attribute(common_constants.DESERIALIZATION) or None - ) - def _do_call(self, argument: Any) -> Any: - """ - Real call method. - """ - # Create a new RpcInvocation object. - invocation = RpcInvocation( + self._serialization = self._url.attributes[common_constants.SERIALIZATION] + + def _create_invocation(self, argument: Any) -> RpcInvocation: + return RpcInvocation( self._service_name, self._method_name, argument, attributes={ common_constants.CALL_KEY: self._call_type, - common_constants.SERIALIZATION: self._request_serializer, - common_constants.DESERIALIZATION: self._response_serializer, + common_constants.SERIALIZATION: self._serialization, }, ) - # Do invoke. - result = self._invoker.invoke(invocation) - return result.get_value() + + +class RpcCallable(AbstractRpcCallable): def __call__(self, argument: Any) -> Any: - return self._do_call(argument) + # Create a new RpcInvocation + invocation = self._create_invocation(argument) + # Do invoke. 
+ result = self._invoker.invoke(invocation) + return result.value() diff --git a/dubbo/client/client.py b/dubbo/client/client.py index ecefa8d..6ab37c3 100644 --- a/dubbo/client/client.py +++ b/dubbo/client/client.py @@ -18,9 +18,9 @@ from dubbo.callable import RpcCallable from dubbo.config import ConsumerConfig, ReferenceConfig from dubbo.constants import common_constants -from dubbo.constants.type_constants import (DeserializingFunction, - SerializingFunction) +from dubbo.constants.type_constants import DeserializingFunction, SerializingFunction from dubbo.logger.logger_factory import loggerFactory +from dubbo.serialization import Serialization logger = loggerFactory.get_logger(__name__) @@ -42,7 +42,10 @@ def unary( response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: return self._callable( - common_constants.CALL_UNARY, method_name, request_serializer, response_deserializer + common_constants.CALL_UNARY, + method_name, + request_serializer, + response_deserializer, ) def client_stream( @@ -106,11 +109,12 @@ def _callable( url = invoker.get_url() # clone url - url = url.clone() + url = url.clone_without_attributes() url.add_parameter(common_constants.METHOD_KEY, method_name) url.add_parameter(common_constants.CALL_KEY, call_type) - url.add_attribute(common_constants.SERIALIZATION, request_serializer) - url.add_attribute(common_constants.DESERIALIZATION, response_deserializer) + + serialization = Serialization(request_serializer, response_deserializer) + url.attributes[common_constants.SERIALIZATION] = serialization # create callable return RpcCallable(invoker, url) diff --git a/dubbo/compressor/compressor.py b/dubbo/compressor/compression.py similarity index 95% rename from dubbo/compressor/compressor.py rename to dubbo/compressor/compression.py index 602a35b..342225b 100644 --- a/dubbo/compressor/compressor.py +++ b/dubbo/compressor/compression.py @@ -15,7 +15,10 @@ # limitations under the License. -class Compressor: +class Compression: + """ + Compression interface + """ def compress(self, data: bytes) -> bytes: """ @@ -27,9 +30,6 @@ def compress(self, data: bytes) -> bytes: """ raise NotImplementedError("compress() is not implemented.") - -class DeCompressor: - def decompress(self, data: bytes) -> bytes: """ Decompress the data diff --git a/dubbo/compressor/gzip_compression.py b/dubbo/compressor/gzip_compression.py new file mode 100644 index 0000000..803bd55 --- /dev/null +++ b/dubbo/compressor/gzip_compression.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
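The per-method URL now carries a single Serialization attribute instead of two separate callables. A minimal sketch of building such a pair for a caller that passes raw bytes through; the two pass-through functions are hypothetical, and only the Serialization(request_serializer, response_deserializer) constructor call mirrors the code above:

from dubbo.serialization import Serialization

def passthrough_request(obj: bytes) -> bytes:
    # hypothetical serializer: the caller already supplies bytes
    return obj

def passthrough_response(data: bytes) -> bytes:
    # hypothetical deserializer: hand the raw payload back unchanged
    return data

serialization = Serialization(passthrough_request, passthrough_response)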
+import gzip + +from dubbo.compressor.compression import Compression + + +class GzipCompression(Compression): + """ + GZIP Compression implementation + """ + + def compress(self, data: bytes) -> bytes: + """ + Compress the data using GZIP + Args: + data (bytes): Data to compress + Returns: + bytes: Compressed data + """ + return gzip.compress(data) + + def decompress(self, data: bytes) -> bytes: + """ + Decompress the data using GZIP + Args: + data (bytes): Data to decompress + Returns: + bytes: Decompressed data + """ + return gzip.decompress(data) diff --git a/dubbo/config/method_config.py b/dubbo/config/method_config.py deleted file mode 100644 index f6c2dcd..0000000 --- a/dubbo/config/method_config.py +++ /dev/null @@ -1,67 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable, Optional - - -class MethodConfig: - """ - MethodConfig is a configuration class for a method. - Attributes: - _interface_name (str): The name of the interface. - _name (str): The name of the method. - _request_serialize (Optional[Callable[..., Any]]): The request serialization function. - _response_deserialize (Optional[Callable[..., Any]]): The response deserialization function. - """ - - _interface_name: str - _name: str - _request_serialize: Optional[Callable[..., Any]] - _response_deserialize: Optional[Callable[..., Any]] - - __slots__ = [ - "_interface_name", - "_name", - "_request_serialize", - "_response_deserialize", - ] - - def __init__( - self, - interface_name: str, - name: str, - request_serialize: Optional[Callable[..., Any]] = None, - response_deserialize: Optional[Callable[..., Any]] = None, - ): - self._interface_name = interface_name - self._name = name - self._request_serialize = request_serialize - self._response_deserialize = response_deserialize - - @property - def interface_name(self) -> str: - return self._interface_name - - @property - def name(self) -> str: - return self._name - - @property - def request_serialize(self) -> Optional[Callable[..., Any]]: - return self._request_serialize - - @property - def response_deserialize(self) -> Optional[Callable[..., Any]]: - return self._response_deserialize diff --git a/dubbo/config/reference_config.py b/dubbo/config/reference_config.py index 3015f50..1e1530d 100644 --- a/dubbo/config/reference_config.py +++ b/dubbo/config/reference_config.py @@ -14,9 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
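The new Compression contract is symmetric, so an implementation can be sanity-checked with a simple round trip (illustrative snippet, not part of the patch):

from dubbo.compressor.gzip_compression import GzipCompression

compression = GzipCompression()
payload = b"hello dubbo " * 64
assert compression.decompress(compression.compress(payload)) == payload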
import threading -from typing import List, Optional +from typing import Optional, Union -from dubbo.config.method_config import MethodConfig from dubbo.extension import extensionLoader from dubbo.protocol.invoker import Invoker from dubbo.protocol.protocol import Protocol @@ -25,36 +24,24 @@ class ReferenceConfig: - _interface_name: str - _check: bool - _url: str - _protocol: str - _methods: List[MethodConfig] + __slots__ = [ + "_initialized", + "_global_lock", + "_service_name", + "_url", + "_protocol", + "_invoker", + ] - _global_lock: threading.Lock - _initialized: bool - _destroyed: bool - _protocol_ins: Optional[Protocol] - _invoker: Optional[Invoker] - - def __init__( - self, - interface_name: str, - url: str, - protocol: str, - methods: Optional[List[MethodConfig]] = None, - ): + def __init__(self, url: Union[str, URL], service_name: str): self._initialized = False self._global_lock = threading.Lock() - self._destroyed = False - self._interface_name = interface_name - self._url = url - self._protocol = protocol - self._methods = methods or [] - - self._invoker = None + self._url: URL = url if isinstance(url, URL) else URL.value_of(url) + self._service_name = service_name + self._protocol: Optional[Protocol] = None + self._invoker: Optional[Invoker] = None - def get_invoker(self): + def get_invoker(self) -> Invoker: if not self._invoker: self._do_init() return self._invoker @@ -63,14 +50,13 @@ def _do_init(self): with self._global_lock: if self._initialized: return - - clazz = extensionLoader.get_extension(Protocol, self._protocol) - # TODO set real URL - self._protocol_ins = clazz(URL.value_of(self._url)) + # Get the interface name from the URL path + self._url.path = self._service_name + self._protocol = extensionLoader.get_extension(Protocol, self._url.scheme)( + self._url + ) self._create_invoker() self._initialized = True def _create_invoker(self): - url = URL.value_of(self._url) - url.path = self._interface_name - self._invoker = self._protocol_ins.refer(url) + self._invoker = self._protocol.refer(self._url) diff --git a/dubbo/constants/common_constants.py b/dubbo/constants/common_constants.py index ebf4a96..cff24c9 100644 --- a/dubbo/constants/common_constants.py +++ b/dubbo/constants/common_constants.py @@ -25,11 +25,11 @@ CALL_CLIENT_STREAM = "client-stream" CALL_SERVER_STREAM = "server-stream" CALL_BIDI_STREAM = "bidi-stream" +ASYNC_KEY = "async" SERIALIZATION = "serialization" -DESERIALIZATION = "deserialization" -COMPRESSOR_KEY = "compressor" -DECOMPRESSOR_KEY = "decompressor" + +COMPRESSION = "compression" SERVER_KEY = "server" METHOD_KEY = "method" @@ -43,4 +43,6 @@ TRANSPORTER_SIDE_KEY = "transporter-side" TRANSPORTER_SIDE_SERVER = "server" TRANSPORTER_SIDE_CLIENT = "client" +TRANSPORTER_PROTOCOL_KEY = "protocol" +TRANSPORTER_STREAM_HANDLER_KEY = "stream-handler" TRANSPORTER_ON_CONN_CLOSE_KEY = "on-conn-close" diff --git a/dubbo/extension/__init__.py b/dubbo/extension/__init__.py index 8744a34..0da2118 100644 --- a/dubbo/extension/__init__.py +++ b/dubbo/extension/__init__.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
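With MethodConfig removed, a reference is described entirely by a URL plus the service name, and the protocol implementation is looked up from the URL scheme. A hedged sketch of the resulting call flow; the address and service name are placeholders, and it assumes the triple protocol is registered under the "tri" scheme:

from dubbo.config import ReferenceConfig

reference = ReferenceConfig("tri://127.0.0.1:50051", "org.apache.dubbo.samples.GreeterService")
invoker = reference.get_invoker()  # lazily resolves the Protocol extension for the URL scheme and refers the service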
-from dubbo.extension.extension_loader import \ - ExtensionLoader as _ExtensionLoader +from dubbo.extension.extension_loader import ExtensionLoader as _ExtensionLoader extensionLoader = _ExtensionLoader() diff --git a/dubbo/extension/registry.py b/dubbo/extension/registry.py index 71904b7..dac28ed 100644 --- a/dubbo/extension/registry.py +++ b/dubbo/extension/registry.py @@ -18,6 +18,7 @@ from dataclasses import dataclass from typing import Any +from dubbo.compressor.compression import Compression from dubbo.logger import LoggerAdapter from dubbo.protocol.protocol import Protocol from dubbo.remoting.transporter import Transporter @@ -44,6 +45,15 @@ class ExtendedRegistry: }, ) +"""Compression registry.""" +compressionRegistry = ExtendedRegistry( + interface=Compression, + impls={ + "gzip": "dubbo.compressor.gzip_compression.GzipCompression", + }, +) + + """Transporter registry.""" transporterRegistry = ExtendedRegistry( interface=Transporter, diff --git a/dubbo/logger/logging/logger_adapter.py b/dubbo/logger/logging/logger_adapter.py index c8a20ca..f4d36b4 100644 --- a/dubbo/logger/logging/logger_adapter.py +++ b/dubbo/logger/logging/logger_adapter.py @@ -43,7 +43,7 @@ class LoggingLoggerAdapter(LoggerAdapter): def __init__(self, config: URL): super().__init__(config) # Set level - level_name = config.parameters.get(logger_constants.LEVEL_KEY) + level_name = config.get_parameter(logger_constants.LEVEL_KEY) self._level = Level.get_level(level_name) if level_name else Level.DEBUG self._update_level() @@ -58,25 +58,21 @@ def get_logger(self, name: str) -> Logger: logger_instance = logging.getLogger(name) # clean up handlers logger_instance.handlers.clear() - parameters = self._config.parameters # Add console handler - if parameters.get( - logger_constants.CONSOLE_ENABLED_KEY, - logger_constants.DEFAULT_CONSOLE_ENABLED_VALUE, - ).lower() == common_constants.TRUE_VALUE or bool( + console_enabled = self._config.get_parameter( + logger_constants.CONSOLE_ENABLED_KEY + ) or str(logger_constants.DEFAULT_CONSOLE_ENABLED_VALUE) + if console_enabled.lower() == common_constants.TRUE_VALUE or bool( sys.stdout and sys.stdout.isatty() ): logger_instance.addHandler(self._get_console_handler()) # Add file handler - if ( - parameters.get( - logger_constants.FILE_ENABLED_KEY, - logger_constants.DEFAULT_FILE_ENABLED_VALUE, - ).lower() - == common_constants.TRUE_VALUE - ): + file_enabled = self._config.get_parameter( + logger_constants.FILE_ENABLED_KEY + ) or str(logger_constants.DEFAULT_FILE_ENABLED_VALUE) + if file_enabled.lower() == common_constants.TRUE_VALUE: logger_instance.addHandler(self._get_file_handler()) if not logger_instance.handlers: @@ -104,33 +100,36 @@ def _get_file_handler(self) -> logging.Handler: Returns: logging.Handler: The file handler. 
""" - parameters = self._config.parameters # Get file path - file_dir = parameters[logger_constants.FILE_DIR_KEY] + file_dir = self._config.get_parameter(logger_constants.FILE_DIR_KEY) file_name = ( - parameters[logger_constants.FILE_NAME_KEY] + self._config.get_parameter(logger_constants.FILE_NAME_KEY) or logger_constants.DEFAULT_FILE_NAME_VALUE ) file_path = os.path.join(file_dir, file_name) # Get backup count backup_count = int( - parameters.get(logger_constants.FILE_BACKUP_COUNT_KEY) + self._config.get_parameter(logger_constants.FILE_BACKUP_COUNT_KEY) or logger_constants.DEFAULT_FILE_BACKUP_COUNT_VALUE ) # Get rotate type - rotate_type = parameters.get(logger_constants.FILE_ROTATE_KEY) + rotate_type = self._config.get_parameter(logger_constants.FILE_ROTATE_KEY) # Set file Handler file_handler: logging.Handler if rotate_type == FileRotateType.SIZE.value: # Set RotatingFileHandler - max_bytes = int(parameters[logger_constants.FILE_MAX_BYTES_KEY]) + max_bytes = int( + self._config.get_parameter(logger_constants.FILE_MAX_BYTES_KEY) + ) file_handler = handlers.RotatingFileHandler( file_path, maxBytes=max_bytes, backupCount=backup_count ) elif rotate_type == FileRotateType.TIME.value: # Set TimedRotatingFileHandler - interval = int(parameters[logger_constants.FILE_INTERVAL_KEY]) + interval = int( + self._config.get_parameter(logger_constants.FILE_INTERVAL_KEY) + ) file_handler = handlers.TimedRotatingFileHandler( file_path, interval=interval, backupCount=backup_count ) diff --git a/dubbo/protocol/result.py b/dubbo/protocol/result.py index 53d0480..c263baf 100644 --- a/dubbo/protocol/result.py +++ b/dubbo/protocol/result.py @@ -29,7 +29,7 @@ def set_value(self, value: Any) -> None: """ raise NotImplementedError("set_value() is not implemented.") - def get_value(self) -> Any: + def value(self) -> Any: """ Get the value of the result """ @@ -43,8 +43,25 @@ def set_exception(self, exception: Exception) -> None: """ raise NotImplementedError("set_exception() is not implemented.") - def get_exception(self) -> Exception: + def exception(self) -> Exception: """ Get the exception to the result """ raise NotImplementedError("get_exception() is not implemented.") + + def add_attachment(self, key: str, value: Any) -> None: + """ + Add an attachment to the result + Args: + key: Key of the attachment + value: Value of the attachment + """ + raise NotImplementedError("add_attachment() is not implemented.") + + def get_attachment(self, key: str) -> Any: + """ + Get an attachment from the result + Args: + key: Key of the attachment + """ + raise NotImplementedError("get_attachment() is not implemented.") diff --git a/tests/loop/__init__.py b/dubbo/protocol/triple/client/__init__.py similarity index 100% rename from tests/loop/__init__.py rename to dubbo/protocol/triple/client/__init__.py diff --git a/dubbo/protocol/triple/client/calls.py b/dubbo/protocol/triple/client/calls.py new file mode 100644 index 0000000..2e6a184 --- /dev/null +++ b/dubbo/protocol/triple/client/calls.py @@ -0,0 +1,156 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Tuple + +from dubbo.compressor.compression import Compression +from dubbo.protocol.triple.tri_codec import TriEncoder +from dubbo.protocol.triple.tri_results import AbstractTriResult +from dubbo.protocol.triple.tri_status import TriRpcStatus +from dubbo.remoting.aio.http2.headers import Http2Headers +from dubbo.remoting.aio.http2.registries import Http2ErrorCode +from dubbo.remoting.aio.http2.stream import Http2Stream +from dubbo.serialization import Serialization + + +class ClientCall: + """ + The client call. + """ + + def __init__(self, listener: "ClientCall.Listener"): + self._listener = listener + self._stream: Optional[Http2Stream] = None + + def bind_stream(self, stream: Http2Stream) -> None: + """ + Bind stream + """ + self._stream = stream + + def send_headers(self, headers: Http2Headers) -> None: + """ + Send headers. + Args: + headers: The headers. + """ + raise NotImplementedError("send_headers() is not implemented.") + + def send_message(self, message: Any, last: bool = False) -> None: + """ + Send message. + Args: + message: The message. + last: Whether this is the last message. + """ + raise NotImplementedError("send_message() is not implemented.") + + def send_reset(self, error_code: Http2ErrorCode) -> None: + """ + Send a reset. + Args: + error_code: The error code. + """ + raise NotImplementedError("send_reset() is not implemented.") + + class Listener: + """ + The listener of the client call. + """ + + def on_message(self, message: Any) -> None: + """ + Called when a message is received. + """ + raise NotImplementedError("on_message() is not implemented.") + + def on_close( + self, rpc_status: TriRpcStatus, trailers: List[Tuple[str, str]] + ) -> None: + """ + Called when the stream is closed. + """ + raise NotImplementedError("on_close() is not implemented.") + + +class TriClientCall(ClientCall): + """ + The triple client call. + """ + + def __init__( + self, + result: AbstractTriResult, + serialization: Serialization, + compression: Optional[Compression] = None, + ): + super().__init__(TriClientCall.Listener(result, serialization)) + self._serialization = serialization + self._tri_encoder = TriEncoder(compression) + + @property + def listener(self) -> "TriClientCall.Listener": + return self._listener + + def send_headers(self, headers: Http2Headers) -> None: + """ + Send headers. + """ + self._stream.send_headers(headers, end_stream=False) + + def send_message(self, message: Any, last: bool = False) -> None: + """ + Send a message. + """ + # Serialize the message + serialized_message = self._serialization.serialize(message) + + # Encode the message + encode_message = self._tri_encoder.encode(serialized_message) + self._stream.send_data(encode_message, end_stream=last) + + def send_reset(self, error_code: Http2ErrorCode) -> None: + """ + Send a reset. + """ + self._stream.send_reset(error_code) + + class Listener(ClientCall.Listener): + """ + The listener of the triple client call. 
+ """ + + def __init__(self, result: AbstractTriResult, serialization: Serialization): + self._result = result + self._serialization = serialization + + def on_message(self, message: Any) -> None: + """ + Called when a message is received. + """ + # Deserialize the message + deserialized_message = self._serialization.deserialize(message) + self._result.set_value(deserialized_message) + + def on_close( + self, rpc_status: TriRpcStatus, trailers: List[Tuple[str, str]] + ) -> None: + """ + Called when the stream is closed. + """ + if rpc_status.cause: + self._result.set_exception(rpc_status.cause) + # Notify the result that the stream is complete + self._result.set_value(self._result.END_SIGNAL) diff --git a/dubbo/protocol/triple/client/stream_listener.py b/dubbo/protocol/triple/client/stream_listener.py new file mode 100644 index 0000000..f757afb --- /dev/null +++ b/dubbo/protocol/triple/client/stream_listener.py @@ -0,0 +1,108 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +from dubbo.compressor.compression import Compression +from dubbo.logger.logger_factory import loggerFactory +from dubbo.protocol.triple.client.calls import ClientCall +from dubbo.protocol.triple.tri_codec import TriDecoder +from dubbo.protocol.triple.tri_constants import TripleHeaderName, TripleHeaderValue +from dubbo.protocol.triple.tri_status import TriRpcCode, TriRpcStatus +from dubbo.remoting.aio.http2.headers import Http2Headers +from dubbo.remoting.aio.http2.registries import Http2ErrorCode +from dubbo.remoting.aio.http2.stream import StreamListener + +logger = loggerFactory.get_logger(__name__) + + +class _TriDecoderListener(TriDecoder.Listener): + """ + Triple decoder listener. + """ + + def __init__(self, listener: ClientCall.Listener): + self._listener = listener + self._rpc_status = None + self._trailers = None + + def add_rpc_status(self, status: TriRpcStatus): + self._rpc_status = status + + def add_trailers(self, trailers: list): + self._trailers = trailers + + def on_message(self, message: Any) -> None: + self._listener.on_message(message) + + def close(self): + self._listener.on_close(self._rpc_status, self._trailers) + + +class TriClientStreamListener(StreamListener): + """ + Stream listener for triple client. 
+ """ + + def __init__( + self, listener: ClientCall.Listener, compression: Optional[Compression] = None + ): + super().__init__() + self._tri_decoder_listener = _TriDecoderListener(listener) + self._tri_decoder = TriDecoder(self._tri_decoder_listener, compression) + + def on_headers(self, headers: Http2Headers, end_stream: bool) -> None: + # validate headers + validated = True + if headers.status != "200": + # Illegal response code + validated = False + logger.error(f"Invalid response code: {headers.status}") + if content_type := headers.get(TripleHeaderName.CONTENT_TYPE.value): + # Invalid content type + if not content_type.startswith(TripleHeaderValue.APPLICATION_GRPC.value): + validated = False + logger.error( + f"Invalid content type: {headers.get(TripleHeaderName.CONTENT_TYPE.value)}" + ) + else: + # Missing content type + validated = False + logger.error("Missing content type") + + if not validated: + # TODO channel by local + pass + + def on_data(self, data: bytes, end_stream: bool) -> None: + # Decode the data + self._tri_decoder.decode(data) + if end_stream: + self._tri_decoder.close() + + def on_trailers(self, headers: Http2Headers) -> None: + tri_status = TriRpcStatus( + TriRpcCode.from_code(int(headers.get(TripleHeaderName.GRPC_STATUS.value))), + description=headers.get(TripleHeaderName.GRPC_MESSAGE.value), + ) + trailers = headers.to_list() + + self._tri_decoder_listener.add_rpc_status(tri_status) + self._tri_decoder_listener.add_trailers(trailers) + + self._tri_decoder.close() + + def on_reset(self, error_code: Http2ErrorCode) -> None: + pass diff --git a/dubbo/protocol/triple/tri_client.py b/dubbo/protocol/triple/tri_client.py deleted file mode 100644 index 5240f61..0000000 --- a/dubbo/protocol/triple/tri_client.py +++ /dev/null @@ -1,196 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import queue -from typing import Any, List, Optional, Tuple - -from dubbo.compressor.compressor import Compressor, DeCompressor -from dubbo.constants import common_constants -from dubbo.constants.common_constants import CALL_CLIENT_STREAM, CALL_UNARY -from dubbo.constants.type_constants import (DeserializingFunction, - SerializingFunction) -from dubbo.extension import extensionLoader -from dubbo.protocol.result import Result -from dubbo.protocol.triple.tri_codec import TriDecoder, TriEncoder -from dubbo.remoting.aio.h2_stream import Stream -from dubbo.url import URL - - -class TriClientCall(Stream.Listener): - - def __init__( - self, - listener: "TriClientCall.Listener", - url: URL, - request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None, - ): - self._stream: Optional[Stream] = None - self._listener = listener - - # Try to get the compressor and decompressor from the URL - self._compressor = self._decompressor = None - if compressor_str := url.get_parameter(common_constants.COMPRESSOR_KEY): - self._compressor = extensionLoader.get_extension(Compressor, compressor_str) - if decompressor_str := url.get_parameter(common_constants.DECOMPRESSOR_KEY): - self._decompressor = extensionLoader.get_extension( - DeCompressor, decompressor_str - ) - - self._compressed = self._compressor is not None - self._encoder = TriEncoder(self._compressor) - self._request_serializer = request_serializer - - class TriDecoderListener(TriDecoder.Listener): - - def __init__( - self, - _listener: "TriClientCall.Listener", - _response_deserializer: Optional[DeserializingFunction] = None, - ): - self._listener = _listener - self._response_deserializer = _response_deserializer - - def on_message(self, message: bytes): - if self._response_deserializer: - message = self._response_deserializer(message) - self._listener.on_message(message) - - def close(self): - self._listener.on_complete() - - self._response_deserializer = response_deserializer - self._decoder = TriDecoder( - TriDecoderListener(self._listener, self._response_deserializer), - self._decompressor, - ) - - self._header_received = False - self._headers = None - self._trailers = None - - def bind_stream(self, stream: Stream) -> None: - """ - Bind stream - """ - self._stream = stream - - def send_headers(self, headers: List[Tuple[str, str]], last: bool = False) -> None: - """ - Send headers - Args: - headers (List[Tuple[str, str]]): Headers - last (bool): Last frame or not - """ - self._stream.send_headers(headers, end_stream=last) - - def send_message(self, message: Any, last: bool = False) -> None: - """ - Send a message - Args: - message (Any): Message to send - last (bool): Last frame or not - """ - if self._request_serializer: - data = self._request_serializer(message) - elif isinstance(message, bytes): - data = message - else: - raise TypeError("Message must be bytes or serialized by req_serializer") - - # Encode data - frame_payload = self._encoder.encode(data, self._compressed) - # Send data frame - self._stream.send_data(frame_payload, end_stream=last) - - def on_headers(self, headers: List[Tuple[str, str]]) -> None: - if not self._header_received: - self._headers = headers - self._header_received = True - else: - # receive trailers - self._trailers = headers - - def on_data(self, data: bytes) -> None: - self._decoder.decode(data) - - def on_complete(self) -> None: - self._decoder.close() - - def on_reset(self, err_code: int) -> None: - # TODO: handle reset - pass - - class Listener: - - 
def on_message(self, message: Any) -> None: - """ - Callback when message is received - """ - raise NotImplementedError("on_message() is not implemented") - - def on_complete(self) -> None: - """ - Callback when the stream is complete - """ - raise NotImplementedError("on_complete() is not implemented") - - -class TriResult(Result): - """ - Triple result - """ - - END_SIGNAL = object() - - def __init__(self, call_type: str): - self._call_type = call_type - self._value_queue = queue.Queue() - self._exception = None - - def set_value(self, value: Any) -> None: - self._value_queue.put(value) - if self._call_type in [CALL_UNARY, CALL_CLIENT_STREAM]: - # Notify the caller that the value is ready - self._value_queue.put(self.END_SIGNAL) - - def get_value(self) -> Any: - if self._call_type in [CALL_UNARY, CALL_CLIENT_STREAM]: - return self._get_single_value() - else: - return self._iterating_values() - - def _get_single_value(self) -> Any: - value = self._value_queue.get() - if value is self.END_SIGNAL: - return None - return value - - def _iterating_values(self) -> Any: - while True: - # block until the value is ready - value = self._value_queue.get() - if value is self.END_SIGNAL: - # break the loop when the value is end signal - break - yield value - - def set_exception(self, exception: Exception) -> None: - # close the value queue - self._value_queue.put(None) - self._exception = exception - - def get_exception(self) -> Exception: - return self._exception diff --git a/dubbo/protocol/triple/tri_codec.py b/dubbo/protocol/triple/tri_codec.py index b0711a7..7cd227b 100644 --- a/dubbo/protocol/triple/tri_codec.py +++ b/dubbo/protocol/triple/tri_codec.py @@ -16,7 +16,7 @@ import struct from typing import Optional -from dubbo.compressor.compressor import Compressor, DeCompressor +from dubbo.compressor.compression import Compression """ gRPC Message Format Diagram @@ -42,29 +42,28 @@ class TriEncoder: This class is responsible for encoding the gRPC message format, which is composed of a header and payload. Args: - compressor (Optional[Compressor]): The compressor to use for compressing the payload. + compression (Optional[Compression]): The Compression to use for compressing or decompressing the payload. """ HEADER_LENGTH: int = 5 COMPRESSED_FLAG_MASK: int = 1 - def __init__(self, compressor: Optional[Compressor]): - self._compressor: Optional[Compressor] = compressor + def __init__(self, compression: Optional[Compression]): + self._compression = compression - def encode(self, message: bytes, compressed: bool = False) -> bytes: + def encode(self, message: bytes) -> bytes: """ Encode the message into the gRPC message format. Args: message (bytes): The message to encode. - compressed (bool): Whether to compress the message. Returns: bytes: The encoded message in gRPC format. """ - compressed_flag = COMPRESSED_FLAG_MASK if compressed else 0 - if compressed: + compressed_flag = COMPRESSED_FLAG_MASK if self._compression else 0 + if self._compression: # Compress the payload - message = self._compressor.compress(message) + message = self._compression.compress(message) message_length = len(message) if message_length > 0xFFFFFFFF: @@ -82,18 +81,18 @@ class TriDecoder: Args: listener (TriDecoder.Listener): The listener to deliver the decoded payload to. - decompressor (Optional[DeCompressor]): The decompressor to use for decompressing the payload. + compression (Optional[Compression]): The Compression to use for compressing or decompressing the payload. 
""" def __init__( self, listener: "TriDecoder.Listener", - decompressor: Optional[DeCompressor], + compression: Optional[Compression], ): # store data for decoding self._accumulate = bytearray() self._listener = listener - self._decompressor = decompressor + self._compression = compression self._state = HEADER self._required_length = HEADER_LENGTH @@ -107,21 +106,21 @@ def __init__( self._closing = False self._closed = False - def decode(self, data: bytes): + def decode(self, data: bytes) -> None: """ Process the incoming bytes, decoding the gRPC message and delivering the payload to the listener. """ self._accumulate.extend(data) self._do_decode() - def close(self): + def close(self) -> None: """ Close the decoder and listener. """ self._closing = True self._do_decode() - def _do_decode(self): + def _do_decode(self) -> None: """ Deliver the accumulated bytes to the listener, processing the header and payload as necessary. """ @@ -143,13 +142,13 @@ def _do_decode(self): finally: self._decoding = False - def _has_enough_bytes(self): + def _has_enough_bytes(self) -> bool: """ Check if the accumulated bytes are enough to process the header or payload """ return len(self._accumulate) >= self._required_length - def _process_header(self): + def _process_header(self) -> None: """ Processes the GRPC compression header which is composed of the compression flag and the outer frame length. """ @@ -165,7 +164,7 @@ def _process_header(self): # Continue to process the payload self._state = PAYLOAD - def _process_payload(self): + def _process_payload(self) -> None: """ Processes the GRPC message body, which depending on frame header flags may be compressed. """ @@ -174,7 +173,7 @@ def _process_payload(self): if self._compressed: # Decompress the payload - payload_bytes = self._decompressor.decompress(payload_bytes) + payload_bytes = self._compression.decompress(payload_bytes) self._listener.on_message(bytes(payload_bytes)) diff --git a/dubbo/protocol/triple/tri_constants.py b/dubbo/protocol/triple/tri_constants.py new file mode 100644 index 0000000..34e3120 --- /dev/null +++ b/dubbo/protocol/triple/tri_constants.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum + + +class TripleHeaderName(enum.Enum): + """ + Header names used in triple protocol. + """ + + CONTENT_TYPE = "content-type" + + TE = "te" + GRPC_STATUS = "grpc-status" + GRPC_MESSAGE = "grpc-message" + GRPC_STATUS_DETAILS_BIN = "grpc-status-details-bin" + GRPC_TIMEOUT = "grpc-timeout" + GRPC_ENCODING = "grpc-encoding" + GRPC_ACCEPT_ENCODING = "grpc-accept-encoding" + + +class TripleHeaderValue(enum.Enum): + """ + Header values used in triple protocol. 
+ """ + + TRAILERS = "trailers" + HTTP = "http" + HTTPS = "https" + APPLICATION_GRPC_PROTO = "application/grpc+proto" + APPLICATION_GRPC = "application/grpc" diff --git a/dubbo/protocol/triple/tri_invoker.py b/dubbo/protocol/triple/tri_invoker.py index 56f60a9..c23bf7f 100644 --- a/dubbo/protocol/triple/tri_invoker.py +++ b/dubbo/protocol/triple/tri_invoker.py @@ -13,41 +13,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple +from typing import Optional +from dubbo.compressor.compression import Compression from dubbo.constants import common_constants +from dubbo.extension import extensionLoader from dubbo.logger.logger_factory import loggerFactory from dubbo.protocol.invocation import Invocation, RpcInvocation from dubbo.protocol.invoker import Invoker from dubbo.protocol.result import Result -from dubbo.protocol.triple.tri_client import TriClientCall, TriResult -from dubbo.remoting.aio.h2_stream_handler import StreamHandler +from dubbo.protocol.triple.client.calls import TriClientCall +from dubbo.protocol.triple.client.stream_listener import TriClientStreamListener +from dubbo.protocol.triple.tri_constants import TripleHeaderName, TripleHeaderValue +from dubbo.protocol.triple.tri_results import TriResult +from dubbo.remoting.aio.http2.headers import Http2Headers, MethodType +from dubbo.remoting.aio.http2.stream_handler import StreamClientMultiplexHandler from dubbo.remoting.transporter import Client from dubbo.url import URL logger = loggerFactory.get_logger(__name__) -class TriClientCallListener(TriClientCall.Listener): - - def __init__(self, result: TriResult): - self._result = result - - def on_message(self, message: Any) -> None: - # Set the message to the result - self._result.set_value(message) - - def on_complete(self) -> None: - # Set the end signal to the result - self._result.set_value(self._result.END_SIGNAL) - - class TriInvoker(Invoker): + """ + Triple invoker. 
+ """ - def __init__(self, url: URL, client: Client, stream_handler: StreamHandler): + def __init__( + self, url: URL, client: Client, stream_multiplexer: StreamClientMultiplexHandler + ): self._url = url self._client = client - self._stream_handler = stream_handler + self._stream_multiplexer = stream_multiplexer + + self._compression: Optional[Compression] = None + compression_type = url.get_parameter(common_constants.COMPRESSION) + if compression_type: + self._compression = extensionLoader.get_extension( + Compression, compression_type + ) self._destroyed = False @@ -55,23 +59,21 @@ def invoke(self, invocation: RpcInvocation) -> Result: call_type = invocation.get_attribute(common_constants.CALL_KEY) result = TriResult(call_type) - # TODO Return an exception result - if self.destroyed: - logger.warning("The invoker has been destroyed.") - raise Exception("The invoker has been destroyed.") - elif not self._client.connected: - pass + if not self._client.is_connected(): + # Reconnect the client + self._client.reconnect() - # Create a new TriClientCall object + # Create a new TriClientCall tri_client_call = TriClientCall( - TriClientCallListener(result), - url=self._url, - request_serializer=invocation.get_attribute(common_constants.SERIALIZATION), - response_deserializer=invocation.get_attribute( - common_constants.DESERIALIZATION - ), + result, + serialization=invocation.get_attribute(common_constants.SERIALIZATION), + compression=self._compression, + ) + + # Create a new stream + stream = self._stream_multiplexer.create( + TriClientStreamListener(tri_client_call.listener, self._compression) ) - stream = self._stream_handler.create(tri_client_call) tri_client_call.bind_stream(stream) if call_type in ( @@ -100,27 +102,32 @@ def _invoke_stream(self, call: TriClientCall, invocation: Invocation) -> None: next_message = message call.send_message(next_message, last=True) - def _create_headers(self, invocation: Invocation) -> List[Tuple[str, str]]: - - headers = [ - (":method", "POST"), - (":authority", self._url.location), - (":scheme", self._url.scheme), - ( - ":path", - f"/{invocation.get_service_name()}/{invocation.get_method_name()}", - ), - ("content-type", "application/grpc+proto"), - ("te", "trailers"), - ] - # TODO Add more headers information + def _create_headers(self, invocation: Invocation) -> Http2Headers: + + headers = Http2Headers() + headers.scheme = TripleHeaderValue.HTTP.value + headers.method = MethodType.POST + headers.authority = self._url.location + # set path + path = "" + if invocation.get_service_name(): + path += f"/{invocation.get_service_name()}" + path += f"/{invocation.get_method_name()}" + headers.path = path + + # set content type + headers.content_type = TripleHeaderValue.APPLICATION_GRPC_PROTO.value + + # set te + headers.add(TripleHeaderName.TE.value, TripleHeaderValue.TRAILERS.value) + return headers def get_url(self) -> URL: return self._url def is_available(self) -> bool: - return self._client.connected + return self._client.is_connected() @property def destroyed(self) -> bool: @@ -129,5 +136,5 @@ def destroyed(self) -> bool: def destroy(self) -> None: self._client.close() self._client = None - self._stream_handler = None + self._stream_multiplexer = None self._url = None diff --git a/dubbo/protocol/triple/tri_protocol.py b/dubbo/protocol/triple/tri_protocol.py index 1f9e6e6..4c28625 100644 --- a/dubbo/protocol/triple/tri_protocol.py +++ b/dubbo/protocol/triple/tri_protocol.py @@ -21,8 +21,8 @@ from dubbo.protocol.invoker import Invoker from 
dubbo.protocol.protocol import Protocol from dubbo.protocol.triple.tri_invoker import TriInvoker -from dubbo.remoting.aio.h2_protocol import H2Protocol -from dubbo.remoting.aio.h2_stream_handler import ClientStreamHandler +from dubbo.remoting.aio.http2.protocol import Http2Protocol +from dubbo.remoting.aio.http2.stream_handler import StreamClientMultiplexHandler from dubbo.remoting.transporter import Transporter from dubbo.url import URL @@ -45,14 +45,17 @@ def refer(self, url: URL) -> Invoker: Args: url (URL): The URL of the remote service. """ - # TODO Simply create it here, then set up a more appropriate configuration that can be configured by the user executor = ThreadPoolExecutor(thread_name_prefix="dubbo-tri-") # Create a stream handler - stream_handler = ClientStreamHandler(executor) - url.add_attribute("protocol", H2Protocol) - url.add_attribute("stream_handler", stream_handler) + stream_multiplexer = StreamClientMultiplexHandler(executor) + # set stream handler and protocol + url.attributes[common_constants.TRANSPORTER_STREAM_HANDLER_KEY] = ( + stream_multiplexer + ) + url.attributes[common_constants.TRANSPORTER_PROTOCOL_KEY] = Http2Protocol + # Create a client client = self._transporter.connect(url) - invoker = TriInvoker(url, client, stream_handler) + invoker = TriInvoker(url, client, stream_multiplexer) self._invokers.append(invoker) return invoker diff --git a/dubbo/protocol/triple/tri_results.py b/dubbo/protocol/triple/tri_results.py new file mode 100644 index 0000000..62d4a27 --- /dev/null +++ b/dubbo/protocol/triple/tri_results.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import queue +from typing import Any, Dict, Optional + +from dubbo.constants.common_constants import CALL_CLIENT_STREAM, CALL_UNARY +from dubbo.protocol.result import Result + + +class AbstractTriResult(Result): + """ + The abstract result. + """ + + END_SIGNAL = object() + + def __init__(self, call_type: str): + self.call_type = call_type + self._exception: Optional[Exception] = None + self._attachments: Dict[str, Any] = {} + + def set_exception(self, exception: Exception) -> None: + self._exception = exception + + def exception(self) -> Exception: + return self._exception + + def add_attachment(self, key: str, value: Any) -> None: + self._attachments[key] = value + + def get_attachment(self, key: str) -> Any: + return self._attachments.get(key) + + +class TriResult(AbstractTriResult): + """ + The triple result. + """ + + def __init__(self, call_type: str): + super().__init__(call_type) + self._values = queue.Queue() + + def set_value(self, value: Any) -> None: + """ + Set the value. + """ + self._values.put(value) + + def value(self) -> Any: + """ + Get the value. 
+ """ + if self.call_type in [CALL_UNARY, CALL_CLIENT_STREAM]: + return self._get_single_value() + else: + return self._iterating_values() + + def _get_single_value(self) -> Any: + """ + Get the single value. + """ + return value if (value := self._values.get()) is not self.END_SIGNAL else None + + def _iterating_values(self) -> Any: + """ + Iterate the values. + """ + return iter(lambda: self._values.get(), self.END_SIGNAL) diff --git a/dubbo/protocol/triple/tri_rpc_status.py b/dubbo/protocol/triple/tri_status.py similarity index 71% rename from dubbo/protocol/triple/tri_rpc_status.py rename to dubbo/protocol/triple/tri_status.py index 98af7a5..c767c24 100644 --- a/dubbo/protocol/triple/tri_rpc_status.py +++ b/dubbo/protocol/triple/tri_status.py @@ -14,44 +14,97 @@ # See the License for the specific language governing permissions and # limitations under the License. import enum +from typing import Optional class TriRpcCode(enum.Enum): """ + RPC status codes. See https://github.com/grpc/grpc/blob/master/doc/statuscodes.md """ # Not an error; returned on success. OK = 0 + # The operation was cancelled, typically by the caller. CANCELLED = 1 + # Unknown error. UNKNOWN = 2 + # The client specified an invalid argument. INVALID_ARGUMENT = 3 + # The deadline expired before the operation could complete. DEADLINE_EXCEEDED = 4 + # Some requested entity (e.g., file or directory) was not found NOT_FOUND = 5 + # The entity that a client attempted to create (e.g., file or directory) already exists. ALREADY_EXISTS = 6 + # The caller does not have permission to execute the specified operation. PERMISSION_DENIED = 7 + # Some resource has been exhausted, perhaps a per-user quota, or perhaps the entire file system is out of space. RESOURCE_EXHAUSTED = 8 + # The operation was rejected because the system is not in a state required for the operation's execution. FAILED_PRECONDITION = 9 + # The operation was aborted, typically due to a concurrency issue such as a sequencer check failure or transaction abort. ABORTED = 10 + # The operation was attempted past the valid range. OUT_OF_RANGE = 11 + # The operation is not implemented or is not supported/enabled in this service. UNIMPLEMENTED = 12 + # Internal errors. INTERNAL = 13 + # The service is currently unavailable. UNAVAILABLE = 14 + # Unrecoverable data loss or corruption. DATA_LOSS = 15 + # The request does not have valid authentication credentials for the operation. UNAUTHENTICATED = 16 + + @classmethod + def from_code(cls, code: int) -> "TriRpcCode": + """ + Get the RPC status code from the given code. + Args: + code: The RPC status code. + """ + for rpc_code in cls: + if rpc_code.value == code: + return rpc_code + return cls.UNKNOWN + + +class TriRpcStatus: + """ + RPC status. + Args: + code: RPC status code. + cause: Optional exception that caused the RPC status. + description: Optional description of the RPC status. 
+ """ + + def __init__( + self, + code: TriRpcCode, + cause: Optional[Exception] = None, + description: Optional[str] = None, + ): + self.code = code + self.cause = cause + self.description = description + + def __repr__(self): + return f"TriRpcStatus(code={self.code}, cause={self.cause}, description={self.description})" diff --git a/dubbo/remoting/aio/aio_transporter.py b/dubbo/remoting/aio/aio_transporter.py index 1e6e128..dc97db4 100644 --- a/dubbo/remoting/aio/aio_transporter.py +++ b/dubbo/remoting/aio/aio_transporter.py @@ -14,13 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +import concurrent import threading -import uuid -from typing import Optional, Tuple +from typing import Optional from dubbo.constants import common_constants from dubbo.logger.logger_factory import loggerFactory -from dubbo.remoting.aio import loop +from dubbo.remoting.aio.event_loop import EventLoop +from dubbo.remoting.aio.exceptions import RemotingException from dubbo.remoting.transporter import Client, Server, Transporter from dubbo.url import URL @@ -38,99 +39,102 @@ def __init__(self, url: URL): super().__init__(url) # Set the side of the transporter to client. + self._protocol = None + + # the event to indicate the connection status of the client + self._connect_event = threading.Event() + # the event to indicate the close status of the client + self._close_future = concurrent.futures.Future() + self._closing = False + self._url.add_parameter( common_constants.TRANSPORTER_SIDE_KEY, common_constants.TRANSPORTER_SIDE_CLIENT, ) + self._url.attributes["connect-event"] = self._connect_event + self._url.attributes["close-future"] = self._close_future - # Set connection closed function - def _connection_lost(exc: Optional[Exception]) -> None: - if exc: - logger.error("Connection lost", exc) - self._connected = False - - self._url.add_attribute( - common_constants.TRANSPORTER_ON_CONN_CLOSE_KEY, _connection_lost - ) - - self._thread: Optional[threading.Thread] = None - self._loop: Optional[asyncio.AbstractEventLoop] = None + self._event_loop: Optional[EventLoop] = None - self._transport: Optional[asyncio.Transport] = None - self._protocol: Optional[asyncio.Protocol] = None - - self._closing = False - - # Open and connect the client - self.open() + # connect to the server self.connect() - def open(self) -> None: + def is_connected(self) -> bool: """ - Create a thread and run asyncio loop in it. + Check if the client is connected. """ - self._loop, self._thread = loop.start_loop_in_thread( - f"dubbo-aio-client-{uuid.uuid4()}" - ) - self._opened = True + return self._connect_event.is_set() - def _create_protocol(self) -> asyncio.Protocol: + def is_closed(self) -> bool: """ - Create the protocol. + Check if the client is closed. """ + return self._close_future.done() or self._closing - return self._url.attributes["protocol"](self._url) + def reconnect(self) -> None: + """ + Reconnect to the server. + """ + self.close() + self._connect_event = threading.Event() + self._close_future = concurrent.futures.Future() + self.connect() def connect(self) -> None: """ Connect to the server. 
""" - if not self._opened: - raise RuntimeError("The client is not opened yet.") - elif self._closed: - raise RuntimeError("The client is closed.") + if self.is_connected(): + return + elif self.is_closed(): + raise RemotingException("The client is closed.") - async def _inner_connect() -> Tuple[asyncio.Transport, asyncio.Protocol]: + async def _inner_operate(): running_loop = asyncio.get_running_loop() - - transport, protocol = await running_loop.create_connection( - lambda: self._url.get_attribute("protocol")(self._url), + _, protocol = await running_loop.create_connection( + lambda: self._url.attributes[common_constants.TRANSPORTER_PROTOCOL_KEY]( + self._url + ), self._url.host, self._url.port, ) - return transport, protocol + return protocol - future = asyncio.run_coroutine_threadsafe(_inner_connect(), self._loop) + # Run the connection logic in the event loop. + if self._event_loop: + self._event_loop.stop() + self._event_loop = EventLoop() + self._event_loop.start() + future = asyncio.run_coroutine_threadsafe( + _inner_operate(), self._event_loop.loop + ) try: - self._transport, self._protocol = future.result() - self._connected = True - logger.info( - f"Connected to the server: ip={self._url.host}, port={self._url.port}" - ) - except Exception as e: - logger.error(f"Failed to connect to the server: {e}") - raise e + self._protocol = future.result() + except ConnectionRefusedError as e: + raise RemotingException("Failed to connect to the server") from e def close(self) -> None: """ - Close the client. just stop the transport. + Close the client. """ - if not self._opened: - raise RuntimeError("The client is not opened yet.") - if self._closing or self._closed: + if self.is_closed(): return self._closing = True - try: - # Close the transport - self._transport.close() - self._connected = False - # Stop the loop - loop.stop_loop_in_thread(self._loop, self._thread) - self._closed = True + self._protocol.close() + if exc := self._protocol.exception(): + raise RemotingException(f"Failed to close the client: {exc}") + except Exception as e: + if not isinstance(e, RemotingException): + # Ignore the exception if it is not RemotingException + pass + else: + # Re-raise RemotingException + raise e finally: + self._event_loop.stop() self._closing = False @@ -142,11 +146,6 @@ class AioServer(Server): def __init__(self, url: URL): self._url = url # Set the side of the transporter to server. - self._url.add_parameter( - common_constants.TRANSPORTER_SIDE_KEY, - common_constants.TRANSPORTER_SIDE_SERVER, - ) - # TODO implement the server class AioTransporter(Transporter): diff --git a/dubbo/remoting/aio/event_loop.py b/dubbo/remoting/aio/event_loop.py new file mode 100644 index 0000000..26de787 --- /dev/null +++ b/dubbo/remoting/aio/event_loop.py @@ -0,0 +1,173 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import threading +import uuid +from typing import Optional + +from dubbo.logger.logger_factory import loggerFactory + +logger = loggerFactory.get_logger(__name__) + + +def _try_use_uvloop() -> None: + """ + Use uvloop instead of the default asyncio running_loop. + """ + import asyncio + import os + + # Check if the operating system. + if os.name == "nt": + # Windows is not supported. + logger.warning( + "Unable to use uvloop, because it is not supported on your operating system." + ) + return + + # Try import uvloop. + try: + import uvloop + except ImportError: + # uvloop is not available. + logger.warning( + "Unable to use uvloop, because it is not installed. " + "You can install it by running `pip install uvloop`." + ) + return + + # Use uvloop instead of the default asyncio running_loop. + if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +# Call the function to try to use uvloop. +_try_use_uvloop() + + +class EventLoop: + + def __init__(self, in_other_tread: bool = True): + self._in_other_tread = in_other_tread + # The event loop to run the asynchronous function. + self._loop = asyncio.new_event_loop() + # The thread to run the event loop. + self._thread: Optional[threading.Thread] = ( + None if in_other_tread else threading.current_thread() + ) + + self._started = False + self._stopped = False + + # The lock to protect the event loop. + self._lock = threading.Lock() + + @property + def loop(self): + """ + Get the event loop. + Returns: + The event loop. + """ + return self._loop + + @property + def thread(self) -> Optional[threading.Thread]: + """ + Get the thread of the event loop. + Returns: + The thread of the event loop. If not yet started, this is None. + """ + return self._thread + + def check_thread(self) -> bool: + """ + Check if the current thread is the event loop thread. + Returns: + If the current thread is the event loop thread, return True. Otherwise, return False. + """ + return threading.current_thread().ident == self._thread.ident + + def is_started(self) -> bool: + """ + Check if the event loop is started. + """ + return self._started + + def start(self): + """ + Start the asyncio event loop. + """ + if self._started: + return + with self._lock: + self._started = True + self._stopped = False + if self._in_other_tread: + self._start_in_thread() + else: + self._start() + + def _start(self) -> None: + """ + Real start the asyncio event loop in current thread. + """ + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + + def _start_in_thread(self) -> None: + """ + Real Start the asyncio event loop in a separate thread. + """ + thread_name = f"dubbo-asyncio-loop-{str(uuid.uuid4())}" + thread = threading.Thread(target=self._start, name=thread_name, daemon=True) + thread.start() + self._thread = thread + + def stop(self, wait: bool = False) -> None: + """ + Stop the asyncio event loop. + """ + if self._stopped: + return + with self._lock: + signal = threading.Event() + asyncio.run_coroutine_threadsafe(self._stop(signal=signal), self._loop) + # Wait for the running_loop to stop + if wait: + signal.wait() + if self._in_other_tread: + self._thread.join() + self._stopped = True + self._started = False + + async def _stop(self, signal: threading.Event) -> None: + """ + Real stop the asyncio event loop. 
+ """ + # Cancel all tasks + tasks = [ + task + for task in asyncio.all_tasks(self._loop) + if task is not asyncio.current_task() + ] + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + # Stop the event running_loop + self._loop.stop() + # Set the signal + signal.set() diff --git a/tests/loop/test_loop_manger.py b/dubbo/remoting/aio/exceptions.py similarity index 58% rename from tests/loop/test_loop_manger.py rename to dubbo/remoting/aio/exceptions.py index 835b92c..4f3d1d6 100644 --- a/tests/loop/test_loop_manger.py +++ b/dubbo/remoting/aio/exceptions.py @@ -13,25 +13,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import unittest -from dubbo.loop.loop_manger import LoopManager +class RemotingException(RuntimeError): + """ + The base exception class for remoting. + """ -async def _loop_task(): - while True: - print("loop task") - await asyncio.sleep(1) + def __init__(self, message: str): + super().__init__(message) -class TestLoopManager(unittest.TestCase): +class ProtocolException(RemotingException): + """ + The exception class for protocol errors. + """ - def test_use_client(self): - loop_manager = LoopManager() - loop = loop_manager.get_client_loop() - asyncio.run_coroutine_threadsafe(_loop_task(), loop) - print("loop task started, waiting for 3 seconds...") - asyncio.run(asyncio.sleep(3)) - loop_manager.destroy_client_loop() - print("loop task stopped.") + def __init__(self, message: str): + super().__init__(message) + + +class StreamException(RemotingException): + """ + The exception class for stream errors. + """ + + def __init__(self, message: str): + super().__init__(message) diff --git a/dubbo/remoting/aio/h2_frame.py b/dubbo/remoting/aio/h2_frame.py deleted file mode 100644 index 0cdc022..0000000 --- a/dubbo/remoting/aio/h2_frame.py +++ /dev/null @@ -1,240 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import enum -import sys -import time -from typing import Any, Dict, Optional - -from h2.events import (DataReceived, Event, RequestReceived, ResponseReceived, - StreamReset, TrailersReceived, WindowUpdated) - - -class H2FrameType(enum.Enum): - """ - Enum class representing HTTP/2 frame types. - """ - - # Data frame, carries HTTP message bodies. - DATA = 0x0 - # Headers frame, carries HTTP headers. - HEADERS = 0x1 - # Priority frame, specifies the priority of a stream. - PRIORITY = 0x2 - # Reset Stream frame, cancels a stream. - RST_STREAM = 0x3 - # Settings frame, exchanges configuration parameters. - SETTINGS = 0x4 - # Push Promise frame, used by the server to push resources. 
- PUSH_PROMISE = 0x5 - # Ping frame, measures round-trip time and checks connectivity. - PING = 0x6 - # Goaway frame, signals that the connection will be closed. - GOAWAY = 0x7 - # Window Update frame, manages flow control window size. - WINDOW_UPDATE = 0x8 - # Continuation frame, transmits large header blocks. - CONTINUATION = 0x9 - - -class H2Frame: - """ - HTTP/2 frame class. It is used to represent an HTTP/2 frame. - Args: - stream_id: The stream identifier. - frame_type: The frame type. - data: The data to send. such as: HEADERS: List[Tuple[str, str]], DATA: bytes, END_STREAM: None or bytes. - end_stream: Whether the stream is ended. - attributes: The attributes of the frame. - """ - - def __init__( - self, - stream_id: int, - frame_type: H2FrameType, - data: Any = None, - end_stream: bool = False, - attributes: Optional[Dict[str, Any]] = None, - ): - self._stream_id = stream_id - self._frame_type = frame_type - self._data = data - self._end_stream = end_stream - self._attributes = attributes or {} - - # The timestamp of the generated frame. -> comparison for Priority Queue - self._timestamp = int(round(time.time() * 1000)) - - @property - def stream_id(self) -> int: - return self._stream_id - - @property - def frame_type(self) -> H2FrameType: - return self._frame_type - - @property - def data(self) -> Any: - return self._data - - @data.setter - def data(self, data: Any) -> None: - self._data = data - - @property - def end_stream(self) -> bool: - return self._end_stream - - @property - def attributes(self) -> Dict[str, Any]: - return self._attributes - - def __lt__(self, other: "H2Frame") -> bool: - return self._timestamp < other._timestamp - - def __str__(self): - return ( - f"H2Frame(stream_id={self.stream_id}, " - f"frame_type={self.frame_type}, " - f"data={self.data}, " - f"end_stream={self.end_stream}, " - f"attributes={self.attributes})" - ) - - -DATA_COMPLETED_FRAME: H2Frame = H2Frame(0, H2FrameType.DATA, b"") -# Make use of the infinity timestamp to ensure that the DATA_COMPLETED_FRAME is always at the end of the data queue. -DATA_COMPLETED_FRAME._timestamp = sys.maxsize - - -class H2FrameUtils: - """ - Utility class for creating HTTP/2 frames. - """ - - @staticmethod - def create_headers_frame( - stream_id: int, - headers: list[tuple[str, str]], - end_stream: bool = False, - attributes: Optional[Dict[str, str]] = None, - ) -> H2Frame: - """ - Create a headers frame. - Args: - stream_id: The stream identifier. - headers: The headers to send. - end_stream: Whether the stream is ended. - attributes: The attributes of the frame. - Returns: - The headers frame. - """ - return H2Frame(stream_id, H2FrameType.HEADERS, headers, end_stream, attributes) - - @staticmethod - def create_data_frame( - stream_id: int, - data: bytes, - end_stream: bool = False, - attributes: Optional[Dict[str, str]] = None, - ) -> H2Frame: - """ - Create a data frame. - Args: - stream_id: The stream identifier. - data: The data to send. - end_stream: Whether the stream is ended. - attributes: The attributes of the frame. - Returns: - The data frame. - """ - return H2Frame(stream_id, H2FrameType.DATA, data, end_stream, attributes) - - @staticmethod - def create_reset_stream_frame( - stream_id: int, - error_code: int, - attributes: Optional[Dict[str, str]] = None, - ) -> H2Frame: - """ - Create a reset stream frame. - Args: - stream_id: The stream identifier. - error_code: The error code. - attributes: The attributes of the frame. - Returns: - The reset stream frame. 
- """ - return H2Frame( - stream_id, - H2FrameType.RST_STREAM, - error_code, - end_stream=True, - attributes=attributes, - ) - - @staticmethod - def create_window_update_frame( - stream_id: int, - increment: int, - attributes: Optional[Dict[str, str]] = None, - ) -> H2Frame: - """ - Create a window update frame. - Args: - stream_id: The stream identifier. - increment: The increment. - attributes: The attributes of the frame. - Returns: - The window update frame. - """ - return H2Frame( - stream_id, H2FrameType.WINDOW_UPDATE, increment, attributes=attributes - ) - - @staticmethod - def create_frame_by_event(event: Event) -> Optional[H2Frame]: - """ - Create a frame by the h2.events.Event. - Args: - event: The h2.events.Event. - Returns: - The H2Frame. None if the event is not supported or not implemented. - """ - if isinstance(event, (RequestReceived, ResponseReceived)): - # The headers frame. - return H2FrameUtils.create_headers_frame( - event.stream_id, event.headers, event.stream_ended is not None - ) - elif isinstance(event, TrailersReceived): - return H2FrameUtils.create_headers_frame( - event.stream_id, event.headers, end_stream=True - ) - elif isinstance(event, DataReceived): - # The data frame. - return H2FrameUtils.create_data_frame( - event.stream_id, - event.data, - end_stream=event.stream_ended is not None, - attributes={"flow_controlled_length": event.flow_controlled_length}, - ) - elif isinstance(event, StreamReset): - # The reset stream frame. - return H2FrameUtils.create_reset_stream_frame( - event.stream_id, event.error_code - ) - elif isinstance(event, WindowUpdated): - # The window update frame. - return H2FrameUtils.create_window_update_frame(event.stream_id, event.delta) diff --git a/dubbo/remoting/aio/h2_protocol.py b/dubbo/remoting/aio/h2_protocol.py deleted file mode 100644 index dd1c73f..0000000 --- a/dubbo/remoting/aio/h2_protocol.py +++ /dev/null @@ -1,368 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import threading -from concurrent.futures import Future as ThreadingFuture -from typing import Dict, Optional, Tuple - -from h2.config import H2Configuration -from h2.connection import H2Connection - -from dubbo.constants import common_constants -from dubbo.logger.logger_factory import loggerFactory -from dubbo.remoting.aio.h2_frame import H2Frame, H2FrameType, H2FrameUtils -from dubbo.remoting.aio.h2_stream_handler import StreamHandler -from dubbo.url import URL - -logger = loggerFactory.get_logger(__name__) - - -class DataFlowControl: - """ - DataFlowControl is responsible for managing HTTP/2 data flow, handling flow control, - and ensuring data frames are sent according to the HTTP/2 flow control rules. 
- - Note: - The class is not thread-safe and does not need to be designed as thread-safe - because there can be only one DataFlowControl corresponding to an HTTP2 connection. - - Args: - protocol (H2Protocol): The protocol instance used to send frames. - loop (asyncio.AbstractEventLoop): The asyncio event loop. - """ - - def __init__(self, protocol, loop: asyncio.AbstractEventLoop): - # The protocol instance used to send frames. - self.protocol: H2Protocol = protocol - - # The asyncio event loop. - self.loop = loop - - # Queue for storing data to be sent out - self._outbound_data_queue: asyncio.Queue[Tuple[H2Frame, asyncio.Event]] = ( - asyncio.Queue() - ) - - # Dictionary for storing data that could not be sent due to flow control limits - self._flow_control_data: Dict[int, Tuple[H2Frame, asyncio.Event]] = {} - - # Set of streams that need to be reset - self._reset_streams = set() - - # Task for the data sender loop. - self._data_sender_loop_task = None - - def start(self) -> None: - """ - Start the data sender loop. - This creates and starts an asyncio task that runs the _data_sender_loop coroutine. - """ - # Start the data sender loop - self._data_sender_loop_task = self.loop.create_task(self._data_sender_loop()) - - def cancel(self) -> None: - """ - Cancel the data sender loop. - This cancels the asyncio task running the _data_sender_loop coroutine. - """ - if self._data_sender_loop_task: - self._data_sender_loop_task.cancel() - - def put(self, frame: H2Frame, event: asyncio.Event) -> None: - """ - Put a data frame into the outbound data queue. - - Args: - frame (H2Frame): The data frame to send. - event (asyncio.Event): The event to notify when the data frame is sent. - """ - self._outbound_data_queue.put_nowait((frame, event)) - - def release(self, frame: H2Frame) -> None: - """ - Release the flow control for the stream. - - Args: - frame (H2Frame): The data frame to release the flow control. - It must be a WINDOW_UPDATE frame. - """ - if frame.frame_type != H2FrameType.WINDOW_UPDATE: - raise TypeError("The frame is not a window update frame") - - stream_id = frame.stream_id - if stream_id: - # This is specific to a single stream. - if stream_id in self._flow_control_data: - data_frame_event = self._flow_control_data.pop(stream_id) - self._outbound_data_queue.put_nowait(data_frame_event) - else: - # This is for the entire connection. - for data_frame_event in self._flow_control_data.values(): - self._outbound_data_queue.put_nowait(data_frame_event) - # Clear the pending data - self._flow_control_data = {} - - def reset(self, frame: H2Frame) -> None: - """ - Reset the stream. - - Args: - frame (H2Frame): The reset frame. It must be an RST_STREAM frame. - """ - if frame.frame_type != H2FrameType.RST_STREAM: - raise TypeError("The frame is not a reset stream frame") - - if frame.stream_id in self._flow_control_data: - del self._flow_control_data[frame.stream_id] - - self._reset_streams.add(frame.stream_id) - - async def _data_sender_loop(self) -> None: - """ - Coroutine that continuously sends data frames from the outbound data queue - while respecting flow control limits. - """ - while True: - # Get the frame from the outbound data queue -> it's a blocking operation, but asynchronous. - data_frame: H2Frame - event: asyncio.Event - data_frame, event = await self._outbound_data_queue.get() - - # If the frame is not a data frame, ignore it. 
- if data_frame.frame_type != H2FrameType.DATA: - logger.warning(f"Invalid frame type: {data_frame.frame_type}, ignored") - event.set() - continue - - # Get the stream ID and data from the frame. - stream_id = data_frame.stream_id - data = data_frame.data - end_stream = data_frame.end_stream - - # The stream has been reset, so we don't send any data. - if stream_id in self._reset_streams: - event.set() - continue - - # We need to send data, but not to exceed the flow control window. - window_size = self.protocol.conn.local_flow_control_window(stream_id) - chunk_size = min(window_size, len(data)) - data_to_send = data[:chunk_size] - data_to_buffer = data[chunk_size:] - - if data_to_send: - # Send the data frame - max_size = self.protocol.conn.max_outbound_frame_size - - # Split the data into chunks and send them out - for x in range(0, len(data), max_size): - chunk = data[x : x + max_size] - end_stream_flag = ( - end_stream - and data_to_buffer == b"" - and x + max_size >= len(data) - ) - self.protocol.conn.send_data( - stream_id, chunk, end_stream=end_stream_flag - ) - - self.protocol.transport.write(self.protocol.conn.data_to_send()) - elif end_stream: - # If there is no data to send, but the stream is ended, send an empty data frame. - self.protocol.conn.send_data(stream_id, b"", end_stream=True) - self.protocol.transport.write(self.protocol.conn.data_to_send()) - - if data_to_buffer: - # Store the data that could not be sent due to flow control limits - data_frame.data = data_to_buffer - self._flow_control_data[stream_id] = (data_frame, event) - else: - # We sent everything. - event.set() - - -class H2Protocol(asyncio.Protocol): - """ - Implements an HTTP/2 protocol using asyncio's Protocol class. - - This class sets up and manages an HTTP/2 connection using the h2 library. - It handles connection state, stream mapping, and data flow control. - - Args: - url (URL): The URL object that contains the connection parameters. - """ - - def __init__(self, url: URL): - self.url = url - # Create the H2 state machine - client_side = ( - self.url.parameters.get( - common_constants.TRANSPORTER_SIDE_KEY, - common_constants.TRANSPORTER_SIDE_CLIENT, - ) - == common_constants.TRANSPORTER_SIDE_CLIENT - ) - h2_config = H2Configuration(client_side=client_side, header_encoding="utf-8") - self.conn: H2Connection = H2Connection(config=h2_config) - - # the backing transport. - self.transport: Optional[asyncio.Transport] = None - - # The asyncio event loop. - self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() - - # A mapping of stream ID to stream object. - self._stream_handler: StreamHandler = self.url.attributes["stream_handler"] - - self._data_follow_control: Optional[DataFlowControl] = None - - def connection_made(self, transport: asyncio.Transport) -> None: - """ - Called when the connection is first established. We complete the following actions: - 1. Save the transport. - 2. Initialize the H2 connection. - 3. Initialize the StreamHandler. - 3. Create the data follow control and start the task. - """ - self.transport = transport - self.conn.initiate_connection() - self.transport.write(self.conn.data_to_send()) - - # Initialize the StreamHandler - self._stream_handler.init(self.loop, self) - - # Create the data follow control object and start the task. - self._data_follow_control = DataFlowControl(self, self.loop) - self._data_follow_control.start() - - def connection_lost(self, exc) -> None: - """ - Called when the connection is lost. 
- Args: - exc: The exception that caused the connection to be lost. - """ - self._stream_handler.destroy() - self._data_follow_control.cancel() - - # Handle the connection close event - if on_conn_lost := self.url.attributes.get( - common_constants.TRANSPORTER_ON_CONN_CLOSE_KEY - ): - if isinstance(on_conn_lost, (asyncio.Event, threading.Event)): - on_conn_lost.set() - elif isinstance(on_conn_lost, (asyncio.Future, ThreadingFuture)): - on_conn_lost.set_result(exc) - elif callable(on_conn_lost): - on_conn_lost(exc) - else: - logger.error("Unable to handle the connection close event") - - def send_headers_frame(self, headers_frame: H2Frame) -> asyncio.Event: - """ - Send headers to the remote peer. (thread-safe) - Note: - Only the first call sends a head frame, if called again, a trailer frame is sent. - Args: - headers_frame(H2Frame): The headers frame to send. - Returns: - asyncio.Event: The event that is set when the headers frame is sent. - """ - headers_event = asyncio.Event() - - def _inner_send_headers_frame(_headers_frame: H2Frame, event: asyncio.Event): - self.conn.send_headers( - _headers_frame.stream_id, _headers_frame.data, _headers_frame.end_stream - ) - self.transport.write(self.conn.data_to_send()) - # Set the event to indicate that the headers frame has been sent. - event.set() - - # Send the header frame - self.loop.call_soon_threadsafe( - _inner_send_headers_frame, headers_frame, headers_event - ) - - return headers_event - - def send_data_frame(self, data_frame: H2Frame) -> asyncio.Event: - """ - Send data to the remote peer. (thread-safe) - The sending of data frames is subject to traffic control. - Args: - data_frame(H2Frame): The data frame to send. - Returns: - asyncio.Event: The event that is set when the data frame is sent. - """ - data_event = asyncio.Event() - - def _inner_send_data_frame(_data_frame: H2Frame, event: asyncio.Event): - self._data_follow_control.put(_data_frame, event) - - self.loop.call_soon_threadsafe(_inner_send_data_frame, data_frame, data_event) - - return data_event - - def send_reset_frame(self, reset_frame: H2Frame) -> None: - """ - Send the reset frame to the remote peer.(thread-safe) - Args: - reset_frame(H2Frame): The reset frame to send. - """ - - def _inner_send_reset_frame(_reset_frame: H2Frame): - self.conn.reset_stream(_reset_frame.stream_id, _reset_frame.data) - self.transport.write(self.conn.data_to_send()) - # remove the stream from the stream handler - self._stream_handler.remove(_reset_frame.stream_id) - - self.loop.call_soon_threadsafe(_inner_send_reset_frame, reset_frame) - - def data_received(self, data: bytes) -> None: - """ - Process inbound data. - """ - events = self.conn.receive_data(data) - # Process the event - for event in events: - frame = H2FrameUtils.create_frame_by_event(event) - if not frame: - # If frame is None, there are two possible cases: - # 1. Events that are handled automatically by the H2 library (e.g. RemoteSettingsChanged, PingReceived). - # -> We just need to send it. - # 2. Events that are not implemented or do not require attention. -> We'll ignore it for now. 
- pass - else: - # The frames we focus on include: HEADERS, DATA, WINDOW_UPDATE, RST_STREAM - if frame.frame_type == H2FrameType.WINDOW_UPDATE: - # Update the flow control window - self._data_follow_control.release(frame) - else: - if frame.frame_type == H2FrameType.RST_STREAM: - # Reset the stream - self._data_follow_control.reset(frame) - # Handle the frame - self._stream_handler.handle_frame(frame) - - # Acknowledge the received data - if frame.frame_type == H2FrameType.DATA: - self.conn.acknowledge_received_data( - frame.attributes["flow_controlled_length"], frame.stream_id - ) - - # If there is data to send, send it. - outbound_data = self.conn.data_to_send() - if outbound_data: - self.transport.write(outbound_data) diff --git a/dubbo/remoting/aio/h2_stream.py b/dubbo/remoting/aio/h2_stream.py deleted file mode 100644 index 05deadd..0000000 --- a/dubbo/remoting/aio/h2_stream.py +++ /dev/null @@ -1,423 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -from typing import List, Optional, Tuple - -from dubbo.logger.logger_factory import loggerFactory -from dubbo.remoting.aio.h2_frame import (DATA_COMPLETED_FRAME, H2Frame, - H2FrameType, H2FrameUtils) - -logger = loggerFactory.get_logger(__name__) - - -class StreamFrameControl: - """ - This class is responsible for controlling the order and sending of frames in an HTTP/2 stream. - It ensures that frames are sent in the correct sequence, specifically HEADERS, DATA (0 or more), - and optional TRAILERS. - - Note: - 1. - This class is not thread-safe and does not need to be designed as thread-safe because it - is used only within a single Stream object. However, asynchronous call safety must be ensured. - 2. Special frames like RESET can be sent without following this sequence. - 3. Each Stream object corresponds to a StreamFrameControl object. - - - Args: - protocol(H2Protocol): The protocol instance used to send frames. - loop(asyncio.AbstractEventLoop): The asyncio event loop. - """ - - def __init__(self, protocol, loop: asyncio.AbstractEventLoop): - # Import here to avoid looping imports - from dubbo.remoting.aio.h2_protocol import H2Protocol - - # The protocol instance used to send frames. - self._protocol: H2Protocol = protocol - - # The asyncio event loop. 
- self._loop = loop - - # The queue for storing frames - # HEADERS: 0, DATA: 1, TRAILERS: 2 - self._frame_queue = asyncio.PriorityQueue() - - # The event for the start of the stream -> Ensure that HEADERS frame have been placed in the queue - self._start_event: asyncio.Event = asyncio.Event() - - # The event for the headers frame -> Ensure that HEADERS frame have been sent - self._headers_event: Optional[asyncio.Event] = None - - # The event for the data frame -> Ensure that previous DATA frame have been sent - self._data_event: Optional[asyncio.Event] = None - - # The flag to indicate whether the data is completed -> Ensure that all data frames have been placed in the queue - self._data_completed = False - - # TRAILERS frame storage - self._trailers_frame: Optional[H2Frame] = None - - self._frame_sender_loop_task = None - - def start(self): - """ - Start the frame sender loop. - This creates and starts an asyncio task that runs the _frame_sender_loop coroutine. - """ - self._frame_sender_loop_task = self._loop.create_task(self._frame_sender_loop()) - - def cancel(self): - """ - Cancel the frame sender loop. - This cancels the asyncio task running the _frame_sender_loop coroutine. - """ - if self._frame_sender_loop_task: - self._frame_sender_loop_task.cancel() - - def put_headers(self, headers_frame: H2Frame): - """ - Put a HEADERS frame into the frame queue. - - Args: - headers_frame (H2Frame): The HEADERS frame to be added. - - Raises: - TypeError: If the frame is not a HEADERS frame. - """ - if headers_frame.frame_type != H2FrameType.HEADERS: - raise TypeError("The frame is not a HEADERS frame") - - # If the start event is not set, set it. - if not self._start_event.is_set(): - # HEADERS - self._frame_queue.put_nowait((0, headers_frame)) - self._start_event.set() - else: - # TRAILERS - self.put_trailers_later(headers_frame) - - def put_data(self, data_frame: H2Frame): - """ - Put a DATA frame into the frame queue. - - Args: - data_frame (H2Frame): The DATA frame to be added. - - Raises: - TypeError: If the frame is not a DATA frame. - RuntimeError: If the data is completed, no more data can be sent. - """ - if data_frame.frame_type != H2FrameType.DATA: - raise TypeError("The frame is not a DATA frame") - elif self._data_completed: - raise RuntimeError("The data is completed, no more data can be sent.") - - if data_frame == DATA_COMPLETED_FRAME: - # The data is completed - self._data_completed = True - if self._trailers_frame: - # Make sure TRAILERS are sent after DATA - self.put_trailers_now(self._trailers_frame) - else: - self._data_completed = data_frame.end_stream - self._frame_queue.put_nowait((1, data_frame)) - - def put_trailers_now(self, trailers_frame: H2Frame): - """ - Immediately put a TRAILERS frame into the frame queue. - - Note: You should call this method when you don't need to send DATA. - - Args: - trailers_frame (H2Frame): The TRAILERS frame to be added. - - Raises: - TypeError: If the frame is not a HEADERS frame. - """ - if trailers_frame.frame_type != H2FrameType.HEADERS: - raise TypeError("The frame is not a HEADERS frame") - - self._frame_queue.put_nowait((2, trailers_frame)) - - def put_trailers_later(self, trailers_frame: H2Frame): - """ - Store the TRAILERS frame to be sent after all DATA frames. - - Note: When you need to send DATA, you should call this method. - - Args: - trailers_frame (H2Frame): The TRAILERS frame to be stored. - - Raises: - TypeError: If the frame is not a HEADERS frame. 
- """ - self._trailers_frame = trailers_frame - - async def _frame_sender_loop(self): - """ - The main loop for sending frames. This loop continuously fetches frames from the queue and sends them in the - correct order. - - It ensures that HEADERS frames are sent before any DATA frames, and waits for the completion events of HEADERS - and DATA frames before sending subsequent frames. - - If a frame has the end_stream flag set, the loop breaks, indicating the end of the stream. - """ - while True: - # Wait for the start event - await self._start_event.wait() - - # Get the frame from the outbound data queue -> it's a blocking operation, but asynchronous. - priority, frame = await self._frame_queue.get() - - # If the frame is HEADERS, send the header frame directly. - if frame.frame_type == H2FrameType.HEADERS and not self._headers_event: - self._headers_event = self._protocol.send_headers_frame(frame) - else: - # Wait for HEADERS to be sent. - await self._headers_event.wait() - - # Waiting for the previous DATA to be sent. - if self._data_event: - await self._data_event.wait() - - if frame.frame_type == H2FrameType.DATA: - # Send the data frame and store the event. - self._data_event = self._protocol.send_data_frame(frame) - elif frame.frame_type == H2FrameType.HEADERS: - # Send the trailers frame. - self._protocol.send_headers_frame(frame) - - if frame.end_stream: - # The stream is completed. we can break the loop. - break - - -class Stream: - """ - Stream is a bidirectional channel that manipulates the data flow between peers. - - This class manages the sending and receiving of HTTP/2 frames for a single stream. - It ensures frames are sent in the correct order and handles flow control for the stream. - - Args: - stream_id (int): The stream identifier. - listener (Stream.Listener): The listener for the stream to handle the received frames. - loop (asyncio.AbstractEventLoop): The asyncio event loop. - protocol (H2Protocol): The protocol instance used to send frames. - - """ - - def __init__( - self, - stream_id: int, - listener: "Stream.Listener", - loop: asyncio.AbstractEventLoop, - protocol, - ): - # import here to avoid circular import - from dubbo.remoting.aio.h2_protocol import H2Protocol - - # The stream ID. - self._stream_id: int = stream_id - # The listener for the stream to handle the received frames. - self._listener: "Stream.Listener" = listener - - # The protocol. - self._protocol: H2Protocol = protocol - - # The asyncio event loop. - self._loop = loop - - # The stream frame control. - self._stream_frame_control = StreamFrameControl(protocol, loop) - self._stream_frame_control.start() - - # The flag to indicate whether the sending is completed. - self._send_completed = False - - # The flag to indicate whether the receiving is completed. - self._receive_completed = False - - def send_headers( - self, headers: List[Tuple[str, str]], end_stream: bool = False - ) -> None: - """ - Send the headers frame. The first call sends the head frame, the second call sends the trailer frame. - - Args: - headers (List[Tuple[str, str]]): The headers to send. - end_stream (bool): Whether to end the stream after sending this frame. 
- """ - if self._send_completed: - return - else: - self._send_completed = end_stream - - def _inner_send_headers(_headers: List[Tuple[str, str]], _end_stream: bool): - headers_frame = H2FrameUtils.create_headers_frame( - self._stream_id, _headers, _end_stream - ) - self._stream_frame_control.put_headers(headers_frame) - - self._loop.call_soon_threadsafe(_inner_send_headers, headers, end_stream) - # Try to close the stream - self.try_close() - - def send_data(self, data: bytes, end_stream: bool = False) -> None: - """ - Send a data frame. - - Args: - data (bytes): The data to send. - end_stream (bool): Whether to end the stream after sending this frame. - """ - if self._send_completed: - return - else: - self._send_completed = end_stream - - def _inner_send_data(_data: bytes, _end_stream: bool): - data_frame = H2FrameUtils.create_data_frame( - self._stream_id, _data, _end_stream - ) - self._stream_frame_control.put_data(data_frame) - - self._loop.call_soon_threadsafe(_inner_send_data, data, end_stream) - # Try to close the stream - self.try_close() - - def send_data_completed(self) -> None: - """ - Indicates that the data frame has been fully sent, but other frames (such as trailers) may still need to be sent. - """ - - def _inner_send_data_completed(): - self._stream_frame_control.put_data(DATA_COMPLETED_FRAME) - - self._loop.call_soon_threadsafe(_inner_send_data_completed) - - def send_reset(self, error_code: int) -> None: - """ - Send a reset frame to terminate the stream. - - Note: This is a special frame and does not need to follow the sequence of frames. - - Args: - error_code (int): The error code indicating the reason for the reset. - """ - self._send_completed = True - - def _inner_send_reset(_error_code: int): - reset_frame = H2FrameUtils.create_reset_stream_frame( - self._stream_id, _error_code - ) - self._protocol.send_reset_frame(reset_frame) - self._stream_frame_control.cancel() - - self._loop.call_soon_threadsafe(_inner_send_reset, error_code) - - # Close the stream immediately. - self.close() - - def receive_headers(self, headers: List[Tuple[str, str]]) -> None: - """ - Called when a headers frame is received. - - Args: - headers (List[Tuple[str, str]]): The headers received. - """ - self._listener.on_headers(headers) - - def receive_data(self, data: bytes) -> None: - """ - Called when a data frame is received. - - Args: - data (bytes): The data received. - """ - self._listener.on_data(data) - - def receive_complete(self) -> None: - """ - Called when the stream is completed. - """ - self._receive_completed = True - # notify the listener - self._listener.on_complete() - # Try to close the stream - self.try_close() - - def receive_reset(self, err_code: int) -> None: - """ - Called when the stream is cancelled by the remote peer. - - Args: - err_code (int): The error code indicating the reason for cancellation. - """ - self._listener.on_reset(err_code) - - def try_close(self) -> None: - """ - Try to close the stream. - """ - if self._send_completed and self._receive_completed: - self.close() - - def close(self) -> None: - """ - Close the stream by cancelling the frame sender loop. - """ - self._stream_frame_control.cancel() - - class Listener: - """ - The listener for the stream to handle the received frames. - """ - - def on_headers(self, headers: List[Tuple[str, str]]) -> None: - """ - Called when a headers frame is received. - - Args: - headers (List[Tuple[str, str]]): The headers received. 
- """ - raise NotImplementedError("on_headers() is not implemented") - - def on_data(self, data: bytes) -> None: - """ - Called when a data frame is received. - - Args: - data (bytes): The data received. - """ - raise NotImplementedError("on_data() is not implemented") - - def on_complete(self) -> None: - """ - Called when the stream is completed. - """ - raise NotImplementedError("on_complete() is not implemented") - - def on_reset(self, err_code: int) -> None: - """ - Called when the stream is cancelled by the remote peer. - - Args: - err_code (int): The error code indicating the reason for cancellation. - """ - raise NotImplementedError("on_reset() is not implemented") diff --git a/dubbo/remoting/aio/h2_stream_handler.py b/dubbo/remoting/aio/h2_stream_handler.py deleted file mode 100644 index 9142eb9..0000000 --- a/dubbo/remoting/aio/h2_stream_handler.py +++ /dev/null @@ -1,181 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -from concurrent.futures import Future as ThreadingFuture -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Optional, Tuple - -from dubbo.logger.logger_factory import loggerFactory -from dubbo.remoting.aio.h2_frame import H2Frame, H2FrameType -from dubbo.remoting.aio.h2_stream import Stream - -logger = loggerFactory.get_logger(__name__) - - -class StreamHandler: - """ - Stream handler class. It is used to handle the stream in the connection. - Args: - executor(ThreadPoolExecutor): The executor to handle the frame. - """ - - def __init__( - self, - executor: Optional[ThreadPoolExecutor] = None, - ): - # import here to avoid circular import - from dubbo.remoting.aio.h2_protocol import H2Protocol - - self._protocol: Optional[H2Protocol] = None - - # The event loop to run the asynchronous function. - self._loop: Optional[asyncio.AbstractEventLoop] = None - - # The streams managed by the handler - self._streams: Dict[int, Stream] = {} - - # The executor to handle the frame, If None, the default executor will be used. - self._executor = executor - - def init(self, loop: asyncio.AbstractEventLoop, protocol) -> None: - """ - Initialize the handler with the protocol. - Args: - loop(asyncio.AbstractEventLoop): The event loop. - protocol(H2Protocol): The protocol. - """ - self._loop = loop - self._protocol = protocol - self._streams.clear() - - def handle_frame(self, frame: H2Frame) -> None: - """ - Handle the frame received from the connection. - Args: - frame: The frame to handle. - """ - # Handle the frame in the executor - self._loop.run_in_executor(self._executor, self._handle_in_executor, frame) - - def _handle_in_executor(self, frame: H2Frame) -> None: - """ - Actually handle the frame in the executor. - Args: - frame: The frame to handle. 
- """ - stream = self._streams.get(frame.stream_id) - - if not stream: - logger.warning(f"Unknown stream: id={frame.stream_id}") - return - - frame_type = frame.frame_type - if frame_type == H2FrameType.HEADERS: - stream.receive_headers(frame.data) - elif frame_type == H2FrameType.DATA: - stream.receive_data(frame.data) - elif frame_type == H2FrameType.RST_STREAM: - stream.receive_reset(frame.data) - else: - logger.debug(f"Unhandled frame: {frame_type}") - - if frame.end_stream: - stream.receive_complete() - - def create(self, listener: Stream.Listener) -> Stream: - """ - Create a new stream. -> Client - Args: - listener: The listener to the stream. - Returns: - Stream: The new stream. - """ - raise NotImplementedError("create() is not implemented") - - def register(self, stream_id: int) -> None: - """ - Register the stream to the handler -> Server - Args: - stream_id: The stream ID. - """ - raise NotImplementedError("register() is not implemented") - - def remove(self, stream_id: int) -> None: - """ - Remove the stream from the handler -> Server - Args: - stream_id: The stream ID. - """ - del self._streams[stream_id] - - def destroy(self) -> None: - """ - Destroy the handler. - """ - for stream in self._streams.values(): - stream.close() - self._streams.clear() - - -class ClientStreamHandler(StreamHandler): - - def create(self, listener: Stream.Listener) -> Stream: - """ - Create a new stream. -> Client - Args: - listener: The listener to the stream. - Returns: - Stream: The new stream. - """ - # Create a new client stream - future = ThreadingFuture() - - def _inner_create(_future: ThreadingFuture): - new_stream_id = self._protocol.conn.get_next_available_stream_id() - new_stream = Stream(new_stream_id, listener, self._loop, self._protocol) - self._streams[new_stream_id] = new_stream - _future.set_result(new_stream) - - self._loop.call_soon_threadsafe(_inner_create, future) - # Return the stream and the listener - return future.result() - - -class ServerStreamHandler(StreamHandler): - - def register(self, stream_id: int) -> Tuple[Stream, Stream.Listener]: - """ - Register the stream to the handler -> Server - Args: - stream_id: The stream ID. - Returns: - (Stream, Stream.Listener): A tuple containing the stream and the listener. - """ - # TODO Create a new listener - new_listener = Stream.Listener() - new_stream = Stream(stream_id, new_listener, self._loop, self._protocol) - self._streams[stream_id] = new_stream - # Return the stream and the listener - return new_stream, new_listener - - def handle_frame(self, frame: H2Frame) -> None: - # Register the stream if it is a HEADERS frame and the stream is not registered. - if ( - frame.frame_type == H2FrameType.HEADERS - and frame.stream_id not in self._streams - ): - self.register(frame.stream_id) - super().handle_frame(frame) diff --git a/dubbo/protocol/triple/tri_listener.py b/dubbo/remoting/aio/http2/__init__.py similarity index 67% rename from dubbo/protocol/triple/tri_listener.py rename to dubbo/remoting/aio/http2/__init__.py index 5f1ab3e..bcba37a 100644 --- a/dubbo/protocol/triple/tri_listener.py +++ b/dubbo/remoting/aio/http2/__init__.py @@ -13,21 +13,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Tuple - -from dubbo.remoting.aio.h2_stream import Stream - - -class TriClientStreamListener(Stream.Listener): - - def on_headers(self, headers: List[Tuple[str, str]]) -> None: - pass - - def on_data(self, data: bytes) -> None: - pass - - def on_complete(self) -> None: - pass - - def on_reset(self, err_code: int) -> None: - pass diff --git a/dubbo/remoting/aio/http2/controllers.py b/dubbo/remoting/aio/http2/controllers.py new file mode 100644 index 0000000..0534bea --- /dev/null +++ b/dubbo/remoting/aio/http2/controllers.py @@ -0,0 +1,348 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import threading +from dataclasses import dataclass +from typing import Dict, Optional, Union + +from h2.connection import H2Connection + +from dubbo.logger.logger_factory import loggerFactory +from dubbo.remoting.aio.http2.frames import DataFrame, HeadersFrame, Http2Frame +from dubbo.remoting.aio.http2.registries import Http2FrameType +from dubbo.remoting.aio.http2.stream import Http2Stream + +logger = loggerFactory.get_logger(__name__) + + +class FollowController: + """ + HTTP/2 stream flow controller. + Note: + This is a thread-unsafe class and must be used in the Http2Protocol class + + Args: + loop: The asyncio event loop. + h2_connection: The H2 connection. + transport: The asyncio transport. + """ + + @dataclass + class StreamItem: + """ + The item for storing stream, flag, and event. + Args: + stream: The stream. + half_close: Whether to close the stream after sending the data. + event: This event is triggered when all data has been sent. + """ + + stream: Http2Stream + half_close: bool + event: asyncio.Event + + def __init__( + self, + loop: asyncio.AbstractEventLoop, + h2_connection: H2Connection, + transport: asyncio.Transport, + ): + self._loop = loop + self._h2_connection = h2_connection + self._transport = transport + + # Collection of all streams that need to send data + self._stream_dict: Dict[int, FollowController.StreamItem] = {} + + # Collection of streams that are currently sending data + self._outbound_stream_queue: asyncio.Queue[FollowController.StreamItem] = ( + asyncio.Queue() + ) + + # Collection of streams that are flow-controlled + self._follow_control_dict: Dict[int, FollowController.StreamItem] = {} + + # Actual storage for the data that needs to be sent + self._data_dict: Dict[int, bytearray] = {} + + # The task for sending data. + self._task = None + + def start(self) -> None: + """ + Start the data sender loop. + This creates and starts an asyncio task that runs the _data_sender_loop coroutine. + """ + self._task = self._loop.create_task(self._send_data()) + + def increment_flow_control_window(self, stream_id: Optional[int]) -> None: + """ + Increment the flow control window size. 
+        Args:
+            stream_id: The stream identifier. If it is None, it means the entire connection.
+        """
+        if stream_id is None or stream_id == 0:
+            # This is for the entire connection.
+            for item in self._follow_control_dict.values():
+                self._outbound_stream_queue.put_nowait(item)
+            self._follow_control_dict = {}
+        elif stream_id in self._follow_control_dict:
+            # This is specific to a single stream.
+            item = self._follow_control_dict.pop(stream_id)
+            self._outbound_stream_queue.put_nowait(item)
+
+    def send_data(
+        self,
+        stream: Http2Stream,
+        data: bytes,
+        half_close: bool,
+        event: Union[asyncio.Event, threading.Event] = None,
+    ):
+        """
+        Send data to the stream.(thread-unsafe)
+        Note:
+            This method is not thread-safe and is expected to be called on the event loop thread.
+        Args:
+            stream: The stream.
+            data: The data to send.
+            half_close: Whether to close the stream after sending the data.
+            event: The event that is triggered when all data has been sent.
+        """
+
+        # Check if the stream is closed
+        if stream.is_local_closed():
+            if event:
+                event.set()
+            logger.warning(f"Stream {stream.id} is closed. Ignoring data {data}")
+        else:
+            # Save the data to the data dictionary
+            if old_data := self._data_dict.get(stream.id):
+                old_data.extend(data)
+                item = self._stream_dict[stream.id]
+                item.half_close = half_close
+                # Update the event
+                if item.event:
+                    item.event.set()
+                item.event = event
+            else:
+                self._data_dict[stream.id] = bytearray(data)
+                self._stream_dict[stream.id] = FollowController.StreamItem(
+                    stream, half_close, event
+                )
+
+            # Put the stream into the outbound stream queue
+            self._outbound_stream_queue.put_nowait(self._stream_dict[stream.id])
+
+    def stop(self) -> None:
+        """
+        Stop the data sender loop.
+        This cancels the asyncio task that runs the _data_sender_loop coroutine.
+        """
+        if self._task:
+            self._task.cancel()
+
+    async def _send_data(self) -> None:
+        """
+        Coroutine that continuously sends data frames from the outbound data queue while respecting flow control limits.
+        """
+        while True:
+            # get the data to send.(async blocking)
+            item = await self._outbound_stream_queue.get()
+
+            # check if the stream is closed
+            stream = item.stream
+            if stream.is_local_closed():
+                # The local side of the stream is closed, so we don't need to send any data.
+                if item.event:
+                    item.event.set()
+                continue
+
+            # get the flow control window size
+            data = self._data_dict.get(stream.id, bytearray())
+            window_size = self._h2_connection.local_flow_control_window(stream.id)
+            chunk_size = min(window_size, len(data))
+            data_to_send = data[:chunk_size]
+            data_to_buffer = data[chunk_size:]
+
+            # send the data
+            if data_to_send or item.half_close:
+                max_size = self._h2_connection.max_outbound_frame_size
+                # Split the data into chunks and send them out
+                for x in range(0, len(data_to_send), max_size):
+                    chunk = data_to_send[x : x + max_size]
+                    end_stream_flag = (
+                        item.half_close
+                        and data_to_buffer == b""
+                        and x + max_size >= len(data_to_send)
+                    )
+                    self._h2_connection.send_data(
+                        stream.id, chunk, end_stream=end_stream_flag
+                    )
+
+                outbound_data = self._h2_connection.data_to_send()
+                if not outbound_data:
+                    # If there is no outbound data to send but the stream needs to be closed,
+                    # send an empty DATA frame with the end_stream flag set to True.
+ self._h2_connection.send_data(stream.id, b"", end_stream=True) + outbound_data = self._h2_connection.data_to_send() + self._transport.write(outbound_data) + + if data_to_buffer: + # Save the data that could not be sent due to flow control limits + self._follow_control_dict[stream.id] = item + self._data_dict[stream.id] = data_to_buffer + else: + # If all data has been sent, trigger the event. + self._data_dict.pop(stream.id) + if item.event: + item.event.set() + + +class FrameOrderController: + """ + HTTP/2 frame writer. This class is responsible for writing frames in the correct order. + Note: + Some special frames do not need to be sorted through this queue, such as RST_STREAM, WINDOW_UPDATE, etc. + Args: + stream: The stream to which the frame belongs. + loop: The asyncio event loop. + protocol: The HTTP/2 protocol. + """ + + def __init__(self, stream: Http2Stream, loop: asyncio.AbstractEventLoop, protocol): + from dubbo.remoting.aio.http2.protocol import Http2Protocol + + self._stream: Http2Stream = stream + self._loop: asyncio.AbstractEventLoop = loop + self._protocol: Http2Protocol = protocol + + # The queue for writing frames. -> keep the order of frames + self._frame_queue: asyncio.PriorityQueue = asyncio.PriorityQueue() + # The task for writing frames. + self._send_frame_task: Optional[asyncio.Task] = None + + # some events + # This event is triggered when a HEADERS frame is placed in the queue. + self._start_event = asyncio.Event() + # This event is triggered when the headers are sent. + self._headers_sent_event: Optional[asyncio.Event] = None + # This event is triggered when the data is sent. + self._data_sent_event: Optional[asyncio.Event] = None + + # The trailers frame. + self._trailers: Optional[HeadersFrame] = None + + def start(self) -> None: + """ + Start the frame writer loop. + This creates and starts an asyncio task that runs the _frame_writer_loop coroutine. + """ + self._send_frame_task = self._loop.create_task(self._write_frame()) + + def write_headers(self, frame: HeadersFrame) -> None: + """ + Write the headers frame to the frame writer queue.(thread-safe) + Args: + frame: The headers frame. + """ + + def _inner_operation(_frame: Http2Frame): + # put the frame into the queue + self._frame_queue.put_nowait((0, _frame)) + # trigger the start event + self._start_event.set() + + self._loop.call_soon_threadsafe(_inner_operation, frame) + + def write_data(self, frame: DataFrame, last: bool = False) -> None: + """ + Write the data frame to the frame writer queue.(thread-safe) + Args: + frame: The data frame. + last: Unlike end_stream, this flag indicates whether the current frame is the last data frame or not. + """ + + def _inner_operation(_frame: Http2Frame, _last: bool): + # put the frame into the queue + self._frame_queue.put_nowait((1, _frame)) + if _last: + # put the trailers frame into the queue + if self._trailers: + self._frame_queue.put_nowait((2, self._trailers)) + + self._loop.call_soon_threadsafe(_inner_operation, frame, last) + + def write_trailers(self, frame: HeadersFrame) -> None: + """ + Write the trailers frame to the frame writer queue.(thread-safe) + Note: + This method is suitable for cases where data frames are not to be sent + Args: + frame: The trailers frame. 
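+        Example (illustrative; `controller`, `stream`, and `trailers` are hypothetical names):
+            trailers_frame = HeadersFrame(stream.id, trailers, end_stream=True)
+            controller.write_trailers(trailers_frame)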
+        """
+
+        def _inner_operation(_frame: Http2Frame):
+            # put the frame into the queue
+            self._frame_queue.put_nowait((2, _frame))
+
+        self._loop.call_soon_threadsafe(_inner_operation, frame)
+
+    def write_trailers_after_data(self, frame: HeadersFrame) -> None:
+        """
+        Write the trailers frame to the frame writer queue.(thread-safe)
+        Note:
+            This method is used to write trailers after the data frame.
+            If the data frame is not sent completely, the trailers frame will not be sent.
+        """
+        self._trailers = frame
+
+    async def _write_frame(self) -> None:
+        """
+        Coroutine that continuously writes frames from the frame queue.
+        """
+        while True:
+            # wait for the start event
+            await self._start_event.wait()
+
+            # get the frame from the queue -> block & async
+            _, frame = await self._frame_queue.get()
+
+            # write the frame
+            if frame.frame_type == Http2FrameType.HEADERS:
+                self._headers_sent_event = self._protocol.write(frame, self._stream)
+            else:
+                # await the headers sent event
+                await self._headers_sent_event.wait()
+
+                # await the data sent event
+                if self._data_sent_event:
+                    await self._data_sent_event.wait()
+
+                self._data_sent_event = self._protocol.write(frame, self._stream)
+
+            # check if the frame is the last frame
+            if frame.end_stream:
+                # close the stream
+                if frame.frame_type != Http2FrameType.DATA:
+                    self._stream.close_local()
+                break
+
+    def stop(self) -> None:
+        """
+        Stop the frame writer loop.
+        This cancels the asyncio task that runs the _frame_writer_loop coroutine.
+        """
+        if self._send_frame_task:
+            self._send_frame_task.cancel()
diff --git a/dubbo/remoting/aio/http2/frames.py b/dubbo/remoting/aio/http2/frames.py
new file mode 100644
index 0000000..173e29b
--- /dev/null
+++ b/dubbo/remoting/aio/http2/frames.py
@@ -0,0 +1,134 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+
+from dubbo.remoting.aio.http2.headers import Http2Headers
+from dubbo.remoting.aio.http2.registries import Http2ErrorCode, Http2FrameType
+
+
+class Http2Frame:
+    """
+    HTTP/2 frame class. It is used to represent an HTTP/2 frame.
+    Args:
+        stream_id: The stream identifier.
+        frame_type: The frame type.
+        end_stream: Whether the stream is ended.
+    """
+
+    def __init__(
+        self,
+        stream_id: int,
+        frame_type: Http2FrameType,
+        end_stream: bool = False,
+    ):
+        self.stream_id = stream_id
+        self.frame_type = frame_type
+        self.end_stream = end_stream
+
+        # The timestamp of the generated frame. -> comparison for Priority Queue
+        self.timestamp = int(round(time.time() * 1000))
+
+    def __lt__(self, other: "Http2Frame") -> bool:
+        return self.timestamp <= other.timestamp
+
+    def __repr__(self) -> str:
+        return f"<Http2Frame: stream_id={self.stream_id}, frame_type={self.frame_type}, end_stream={self.end_stream}>"
+
+
+class HeadersFrame(Http2Frame):
+    """
+    HTTP/2 headers frame.
+    Args:
+        stream_id: The stream identifier.
+        headers: The HTTP/2 headers.
+        end_stream: Whether the stream is ended.
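+    Example (illustrative; `headers` is assumed to be an Http2Headers instance):
+        frame = HeadersFrame(stream_id=1, headers=headers, end_stream=False)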
+    """
+
+    def __init__(
+        self,
+        stream_id: int,
+        headers: Http2Headers,
+        end_stream: bool = False,
+    ):
+        super().__init__(stream_id, Http2FrameType.HEADERS, end_stream)
+        self.headers = headers
+
+    def __repr__(self) -> str:
+        return f"<HeadersFrame: stream_id={self.stream_id}, headers={self.headers}, end_stream={self.end_stream}>"
+
+
+class DataFrame(Http2Frame):
+    """
+    HTTP/2 data frame.
+    Args:
+        stream_id: The stream identifier.
+        data: The data to send.
+        data_length: The amount of data received that counts against the flow control window.
+        end_stream: Whether the stream is ended.
+    """
+
+    def __init__(
+        self,
+        stream_id: int,
+        data: bytes,
+        data_length: int,
+        end_stream: bool = False,
+    ):
+        super().__init__(stream_id, Http2FrameType.DATA, end_stream)
+        self.data = data
+        self.data_length = data_length
+
+    def __repr__(self) -> str:
+        return f"<DataFrame: stream_id={self.stream_id}, data_length={self.data_length}, end_stream={self.end_stream}>"
+
+
+class WindowUpdateFrame(Http2Frame):
+    """
+    HTTP/2 window update frame.
+    Args:
+        stream_id: The stream identifier.
+        delta: The number of bytes by which to increase the flow control window.
+    """
+
+    def __init__(
+        self,
+        stream_id: int,
+        delta: int,
+    ):
+        super().__init__(stream_id, Http2FrameType.WINDOW_UPDATE, False)
+        self.delta = delta
+
+    def __repr__(self) -> str:
+        return f"<WindowUpdateFrame: stream_id={self.stream_id}, delta={self.delta}>"
+
+
+class ResetStreamFrame(Http2Frame):
+    """
+    HTTP/2 reset stream frame.
+    Args:
+        stream_id: The stream identifier.
+        error_code: The error code that indicates the reason for closing the stream.
+    """
+
+    def __init__(
+        self,
+        stream_id: int,
+        error_code: Http2ErrorCode,
+    ):
+        super().__init__(stream_id, Http2FrameType.RST_STREAM, True)
+        self.error_code = error_code
+
+    def __repr__(self) -> str:
+        return f"<ResetStreamFrame: stream_id={self.stream_id}, error_code={self.error_code}>"
diff --git a/dubbo/remoting/aio/http2/headers.py b/dubbo/remoting/aio/http2/headers.py
new file mode 100644
index 0000000..293248f
--- /dev/null
+++ b/dubbo/remoting/aio/http2/headers.py
@@ -0,0 +1,195 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import enum
+from collections import OrderedDict
+from typing import List, Optional, Tuple, Union
+
+
+class PseudoHeaderName(enum.Enum):
+    """
+    Pseudo-header names defined in RFC 7540 Section 8.1.2.
+    See: https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2
+    """
+
+    SCHEME = ":scheme"
+    # Request pseudo-headers
+    METHOD = ":method"
+    AUTHORITY = ":authority"
+    PATH = ":path"
+    # Response pseudo-headers
+    STATUS = ":status"
+
+
+class MethodType(enum.Enum):
+    """
+    HTTP/2 method types.
+    """
+
+    GET = "GET"
+    POST = "POST"
+    PUT = "PUT"
+    DELETE = "DELETE"
+    HEAD = "HEAD"
+    OPTIONS = "OPTIONS"
+    PATCH = "PATCH"
+    TRACE = "TRACE"
+    CONNECT = "CONNECT"
+
+
+class Http2Headers:
+    """
+    HTTP/2 headers.
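+    Example (illustrative usage; the path value below is a hypothetical service path):
+        headers = Http2Headers()
+        headers.method = "POST"
+        headers.scheme = "http"
+        headers.path = "/org.example.GreeterService/sayHello"
+        headers.add("content-type", "application/grpc")
+        header_list = headers.to_list()  # pseudo-headers first, then the added headers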
+ """ + + def __init__(self): + self._headers: OrderedDict[str, Optional[str]] = OrderedDict() + self._init() + + def _init(self): + # keep the order of headers + self._headers[PseudoHeaderName.SCHEME.value] = None + self._headers[PseudoHeaderName.METHOD.value] = None + self._headers[PseudoHeaderName.AUTHORITY.value] = None + self._headers[PseudoHeaderName.PATH.value] = None + self._headers[PseudoHeaderName.STATUS.value] = None + + def add(self, name: str, value: str) -> None: + """ + Add a header. + Args: + name: The header name. + value: The header value. + """ + self._headers[name] = value + + def get(self, name: str) -> Optional[str]: + """ + Get the header value. + Returns: + The header value: If the header exists, return the value. Otherwise, return None. + """ + return self._headers.get(name, None) + + @property + def method(self) -> Optional[str]: + """ + Get the method. + """ + return self.get(PseudoHeaderName.METHOD.value) + + @method.setter + def method(self, value: Union[MethodType, str]) -> None: + """ + Set the method. + Args: + value: The method value. + """ + if isinstance(value, MethodType): + value = value.value + else: + value = value.upper() + self.add(PseudoHeaderName.METHOD.value, value) + + @property + def scheme(self) -> Optional[str]: + """ + Get the scheme. + """ + return self.get(PseudoHeaderName.SCHEME.value) + + @scheme.setter + def scheme(self, value: str) -> None: + """ + Set the scheme. + Args: + value: The scheme value. + """ + self.add(PseudoHeaderName.SCHEME.value, value) + + @property + def authority(self) -> Optional[str]: + """ + Get the authority. + """ + return self.get(PseudoHeaderName.AUTHORITY.value) + + @authority.setter + def authority(self, value: str) -> None: + """ + Set the authority. + Args: + value: The authority value. + """ + self.add(PseudoHeaderName.AUTHORITY.value, value) + + @property + def path(self) -> Optional[str]: + """ + Get the path. + """ + return self.get(PseudoHeaderName.PATH.value) + + @path.setter + def path(self, value: str) -> None: + """ + Set the path. + Args: + value: The path value. + """ + self.add(PseudoHeaderName.PATH.value, value) + + @property + def status(self) -> Optional[str]: + """ + Get the status code. + """ + return self.get(PseudoHeaderName.STATUS.value) + + @status.setter + def status(self, value: str) -> None: + """ + Set the status code. + Args: + value: The status code. + """ + self.add(PseudoHeaderName.STATUS.value, value) + + def to_list(self) -> List[Tuple[str, str]]: + """ + Convert the headers to a list. The list contains all non-None headers. + Returns: + The headers list. + """ + return [ + (name, value) for name, value in self._headers.items() if value is not None + ] + + def __repr__(self) -> str: + return f"" + + @classmethod + def from_list(cls, headers: List[Tuple[str, str]]) -> "Http2Headers": + """ + Create an Http2Headers object from a list. + Args: + headers: The headers list. + Returns: + The Http2Headers object. + """ + http2_headers = cls() + for name, value in headers: + http2_headers.add(name, value) + return http2_headers diff --git a/dubbo/remoting/aio/http2/protocol.py b/dubbo/remoting/aio/http2/protocol.py new file mode 100644 index 0000000..e42bb9b --- /dev/null +++ b/dubbo/remoting/aio/http2/protocol.py @@ -0,0 +1,213 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import concurrent +from typing import List, Optional, Tuple, Union + +from h2.config import H2Configuration +from h2.connection import H2Connection + +from dubbo.constants import common_constants +from dubbo.logger.logger_factory import loggerFactory +from dubbo.remoting.aio.exceptions import ProtocolException +from dubbo.remoting.aio.http2.controllers import FollowController +from dubbo.remoting.aio.http2.frames import Http2Frame +from dubbo.remoting.aio.http2.registries import Http2FrameType +from dubbo.remoting.aio.http2.stream import Http2Stream +from dubbo.remoting.aio.http2.utils import Http2EventUtils +from dubbo.url import URL + +logger = loggerFactory.get_logger(__name__) + + +class Http2Protocol(asyncio.Protocol): + + def __init__(self, url: URL): + self._url = url + self._loop = asyncio.get_running_loop() + + # Create the H2 state machine + side_client = ( + self._url.get_parameter(common_constants.TRANSPORTER_SIDE_KEY) + == common_constants.TRANSPORTER_SIDE_CLIENT + ) + h2_config = H2Configuration(client_side=side_client, header_encoding="utf-8") + self._h2_connection: H2Connection = H2Connection(config=h2_config) + + # The transport instance + self._transport: Optional[asyncio.Transport] = None + + self._follow_controller: Optional[FollowController] = None + + self._stream_handler = self._url.attributes[ + common_constants.TRANSPORTER_STREAM_HANDLER_KEY + ] + + def connection_made(self, transport: asyncio.Transport): + """ + Called when the connection is first established. We complete the following actions: + 1. Save the transport. + 2. Initialize the H2 connection. + 3. Create and start the follow controller. + 4. Initialize the stream handler. + """ + self._transport = transport + self._h2_connection.initiate_connection() + self._transport.write(self._h2_connection.data_to_send()) + + # Create and start the follow controller + self._follow_controller = FollowController( + self._loop, self._h2_connection, self._transport + ) + self._follow_controller.start() + + # Initialize the stream handler + self._stream_handler.do_init(self._loop, self) + + # Notify the connection is established + if event := self._url.attributes.get("connect-event"): + event.set() + + def get_next_stream_id( + self, future: Union[asyncio.Future, concurrent.futures.Future] + ) -> None: + """ + Create a new stream.(thread-safe) + Args: + future: The future to set the stream identifier. + """ + + def _inner_operation(_future: Union[asyncio.Future, concurrent.futures.Future]): + stream_id = self._h2_connection.get_next_available_stream_id() + _future.set_result(stream_id) + + self._loop.call_soon_threadsafe(_inner_operation, future) + + def write(self, frame: Http2Frame, stream: Http2Stream) -> asyncio.Event: + """ + Send the HTTP/2 frame.(thread-safe) + Args: + frame: The HTTP/2 frame. + stream: The HTTP/2 stream. + Returns: + The event to be set after sending the frame. 
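+        Example (illustrative; `protocol`, `headers_frame`, and `stream` are hypothetical names):
+            sent_event = protocol.write(headers_frame, stream)
+            # sent_event is an asyncio.Event set on the event loop once the frame has been written.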
+ """ + event = asyncio.Event() + self._loop.call_soon_threadsafe(self._send_frame, frame, stream, event) + return event + + def _send_frame(self, frame: Http2Frame, stream: Http2Stream, event: asyncio.Event): + """ + Send the HTTP/2 frame.(thread-unsafe) + Args: + frame: The HTTP/2 frame. + stream: The HTTP/2 stream. + event: The event to be set after sending the frame. + """ + frame_type = frame.frame_type + if frame_type == Http2FrameType.HEADERS: + self._send_headers_frame( + frame.stream_id, frame.headers.to_list(), frame.end_stream, event + ) + elif frame_type == Http2FrameType.DATA: + self._follow_controller.send_data( + stream, frame.data, frame.end_stream, event + ) + elif frame_type == Http2FrameType.RST_STREAM: + self._send_reset_frame(frame.stream_id, frame.error_code.value, event) + else: + logger.warning(f"Unhandled frame: {frame}") + + def _send_headers_frame( + self, + stream_id: int, + headers: List[Tuple[str, str]], + end_stream: bool, + event: Optional[asyncio.Event] = None, + ): + """ + Send the HTTP/2 headers frame.(thread-unsafe) + Args: + stream_id: The stream identifier. + headers: The headers to send. + end_stream: Whether the stream is ended. + event: The event to be set after sending the frame. + """ + self._h2_connection.send_headers(stream_id, headers, end_stream=end_stream) + self._transport.write(self._h2_connection.data_to_send()) + if event: + event.set() + + def _send_reset_frame( + self, stream_id: int, error_code: int, event: Optional[asyncio.Event] = None + ): + """ + Send the HTTP/2 reset frame.(thread-unsafe) + Args: + stream_id: The stream identifier. + error_code: The error code. + event: The event to be set after sending the frame + """ + self._h2_connection.reset_stream(stream_id, error_code) + self._transport.write(self._h2_connection.data_to_send()) + if event: + event.set() + + def data_received(self, data): + events = self._h2_connection.receive_data(data) + # Process the event + try: + for event in events: + frame = Http2EventUtils.convert_to_frame(event) + if frame is not None: + if frame.frame_type == Http2FrameType.WINDOW_UPDATE: + # Because flow control may be at the connection level, it is handled here + self._follow_controller.increment_flow_control_window( + frame.stream_id + ) + else: + self._stream_handler.handle_frame(frame) + + # If frame is None, there are two possible cases: + # 1. Events that are handled automatically by the H2 library (e.g. RemoteSettingsChanged, PingReceived). + # -> We just need to send it. + # 2. Events that are not implemented or do not require attention. -> We'll ignore it for now. + if outbound_data := self._h2_connection.data_to_send(): + self._transport.write(outbound_data) + + except Exception as e: + raise ProtocolException("Failed to process the Http/2 event.") from e + + def close(self): + """ + Close the connection. + """ + self._h2_connection.close_connection() + self._transport.write(self._h2_connection.data_to_send()) + + self._transport.close() + + def connection_lost(self, exc): + """ + Called when the connection is lost. 
+ """ + self._follow_controller.stop() + # Notify the connection is established + if future := self._url.attributes.get("close-future"): + if exc: + future.set_exception(exc) + else: + future.set_result(None) diff --git a/dubbo/remoting/aio/http2/registries.py b/dubbo/remoting/aio/http2/registries.py new file mode 100644 index 0000000..69ac023 --- /dev/null +++ b/dubbo/remoting/aio/http2/registries.py @@ -0,0 +1,289 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum +from typing import Optional + + +class Http2FrameType(enum.Enum): + """ + Frame types are used in the frame header to identify the type of the frame. + See: https://datatracker.ietf.org/doc/html/rfc7540#section-11.2 + """ + + # Data frame, carries HTTP message bodies. + DATA = 0x0 + + # Headers frame, carries HTTP headers. + HEADERS = 0x1 + + # Priority frame, specifies the priority of a stream. + PRIORITY = 0x2 + + # Reset Stream frame, cancels a stream. + RST_STREAM = 0x3 + + # Settings frame, exchanges configuration parameters. + SETTINGS = 0x4 + + # Push Promise frame, used by the server to push resources. + PUSH_PROMISE = 0x5 + + # Ping frame, measures round-trip time and checks connectivity. + PING = 0x6 + + # Goaway frame, signals that the connection will be closed. + GOAWAY = 0x7 + + # Window Update frame, manages flow control window size. + WINDOW_UPDATE = 0x8 + + # Continuation frame, transmits large header blocks. + CONTINUATION = 0x9 + + +class Http2ErrorCode(enum.Enum): + """ + Error codes are 32-bit fields that are used in RST_STREAM and GOAWAY frames to convey the reasons for the stream or connection error. + + see: https://datatracker.ietf.org/doc/html/rfc7540#section-11.4 + """ + + # The associated condition is not a result of an error. + NO_ERROR = 0x0 + + # The endpoint detected an unspecific protocol error. + PROTOCOL_ERROR = 0x1 + + # The endpoint encountered an unexpected internal error. + INTERNAL_ERROR = 0x2 + + # The endpoint detected that its peer violated the flow-control protocol. + FLOW_CONTROL_ERROR = 0x3 + + # The endpoint sent a SETTINGS frame but did not receive a response in a timely manner. + SETTINGS_TIMEOUT = 0x4 + + # The endpoint received a frame after a stream was half-closed. + STREAM_CLOSED = 0x5 + + # The endpoint received a frame with an invalid size. + FRAME_SIZE_ERROR = 0x6 + + # The endpoint refused the stream prior to performing any application processing + REFUSED_STREAM = 0x7 + + # Used by the endpoint to indicate that the stream is no longer needed. + CANCEL = 0x8 + + # The endpoint is unable to maintain the header compression context for the connection. + COMPRESSION_ERROR = 0x9 + + # The connection established in response to a CONNECT request (Section 8.3) was reset or abnormally closed. 
+ CONNECT_ERROR = 0xA + + # The endpoint detected that its peer is exhibiting a behavior that might be generating excessive load. + ENHANCE_YOUR_CALM = 0xB + + # The underlying transport has properties that do not meet minimum security requirements (see Section 9.2). + INADEQUATE_SECURITY = 0xC + + # The endpoint requires that HTTP/1.1 be used instead of HTTP/2. + HTTP_1_1_REQUIRED = 0xD + + @classmethod + def get(cls, code: int): + """ + Get the error code by code. + Args: + code: The error code. + Returns: + The error code. + """ + for error_code in cls: + if error_code.value == code: + return error_code + # Unknown or unsupported error codes MUST NOT trigger any special behavior. + # These MAY be treated as equivalent to INTERNAL_ERROR. + return cls.INTERNAL_ERROR + + +class Http2Settings: + """ + The settings are used to communicate configuration parameters that affect how endpoints communicate. + See: https://datatracker.ietf.org/doc/html/rfc7540#section-11.3 + """ + + class Http2Setting: + """ + HTTP/2 setting. + """ + + def __init__(self, code: int, initial_value: Optional[int] = None): + self.code = code + # If the initial value is "none", it means no limitation. + self.initial_value = initial_value + + # Allows the sender to inform the remote endpoint of the maximum size of the header compression table used to decode header blocks, in octets. + HEADER_TABLE_SIZE = Http2Setting(0x1, 4096) + + # This setting can be used to disable server push (Section 8.2). + ENABLE_PUSH = Http2Setting(0x2, 1) + + # Indicates the maximum number of concurrent streams that the sender will allow. + MAX_CONCURRENT_STREAMS = Http2Setting(0x3, None) + + # Indicates the sender's initial window size (in octets) for stream-level flow control. + # This setting affects the window size of all streams + INITIAL_WINDOW_SIZE = Http2Setting(0x4, 65535) + + # Indicates the size of the largest frame payload that the sender is willing to receive, in octets. + MAX_FRAME_SIZE = Http2Setting(0x5, 16384) + + # This advisory setting informs a peer of the maximum size of header list that the sender is prepared to accept, in octets. + MAX_HEADER_LIST_SIZE = Http2Setting(0x6, None) + + +class HttpStatus(enum.Enum): + """ + Enum for HTTP status codes as defined in RFC 7231 and related specifications. 
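+    Example (illustrative):
+        status = HttpStatus.from_code(404)  # -> HttpStatus.NOT_FOUND
+        HttpStatus.is_4xx(status)           # -> True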
+ """ + + # 1xx Informational + CONTINUE = 100 + SWITCHING_PROTOCOLS = 101 + + # 2xx Success + OK = 200 + CREATED = 201 + ACCEPTED = 202 + NON_AUTHORITATIVE_INFORMATION = 203 + NO_CONTENT = 204 + RESET_CONTENT = 205 + PARTIAL_CONTENT = 206 + + # 3xx Redirection + MULTIPLE_CHOICES = 300 + MOVED_PERMANENTLY = 301 + FOUND = 302 + SEE_OTHER = 303 + NOT_MODIFIED = 304 + USE_PROXY = 305 + TEMPORARY_REDIRECT = 307 + PERMANENT_REDIRECT = 308 + + # 4xx Client Error + BAD_REQUEST = 400 + UNAUTHORIZED = 401 + PAYMENT_REQUIRED = 402 + FORBIDDEN = 403 + NOT_FOUND = 404 + METHOD_NOT_ALLOWED = 405 + NOT_ACCEPTABLE = 406 + PROXY_AUTHENTICATION_REQUIRED = 407 + REQUEST_TIMEOUT = 408 + CONFLICT = 409 + GONE = 410 + LENGTH_REQUIRED = 411 + PRECONDITION_FAILED = 412 + PAYLOAD_TOO_LARGE = 413 + URI_TOO_LONG = 414 + UNSUPPORTED_MEDIA_TYPE = 415 + RANGE_NOT_SATISFIABLE = 416 + EXPECTATION_FAILED = 417 + I_AM_A_TEAPOT = 418 + MISDIRECTED_REQUEST = 421 + UNPROCESSABLE_ENTITY = 422 + LOCKED = 423 + FAILED_DEPENDENCY = 424 + UPGRADE_REQUIRED = 426 + PRECONDITION_REQUIRED = 428 + TOO_MANY_REQUESTS = 429 + REQUEST_HEADER_FIELDS_TOO_LARGE = 431 + UNAVAILABLE_FOR_LEGAL_REASONS = 451 + + # 5xx Server Error + INTERNAL_SERVER_ERROR = 500 + NOT_IMPLEMENTED = 501 + BAD_GATEWAY = 502 + SERVICE_UNAVAILABLE = 503 + GATEWAY_TIMEOUT = 504 + HTTP_VERSION_NOT_SUPPORTED = 505 + VARIANT_ALSO_NEGOTIATES = 506 + INSUFFICIENT_STORAGE = 507 + LOOP_DETECTED = 508 + NOT_EXTENDED = 510 + NETWORK_AUTHENTICATION_REQUIRED = 511 + + @classmethod + def from_code(cls, code: int) -> "HttpStatus": + for status in cls: + if status.value == code: + return status + + @staticmethod + def is_1xx(status): + """ + Check if the given status is an informational (1xx) status code. + Args: + status: HttpStatus to check + Returns: + True if the status code is in the 1xx range, False otherwise + """ + return 100 <= status.value < 200 + + @staticmethod + def is_2xx(status): + """ + Check if the given status is a successful (2xx) status code. + Args: + status: HttpStatus to check + Returns: + True if the status code is in the 2xx range, False otherwise + """ + return 200 <= status.value < 300 + + @staticmethod + def is_3xx(status): + """ + Check if the given status is a redirection (3xx) status code. + Args: + status: HttpStatus to check + Returns: + True if the status code is in the 3xx range, False otherwise + """ + return 300 <= status.value < 400 + + @staticmethod + def is_4xx(status): + """ + Check if the given status is a client error (4xx) status code. + Args: + status: HttpStatus to check + Returns: + True if the status code is in the 4xx range, False otherwise + """ + return 400 <= status.value < 500 + + @staticmethod + def is_5xx(status): + """ + Check if the given status is a server error (5xx) status code. + Args: + status: HttpStatus to check + Returns: + True if the status code is in the 5xx range, False otherwise + """ + return 500 <= status.value < 600 diff --git a/dubbo/remoting/aio/http2/stream.py b/dubbo/remoting/aio/http2/stream.py new file mode 100644 index 0000000..da6ee4a --- /dev/null +++ b/dubbo/remoting/aio/http2/stream.py @@ -0,0 +1,278 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from typing import Optional + +from dubbo.remoting.aio.exceptions import StreamException +from dubbo.remoting.aio.http2.frames import ( + DataFrame, + HeadersFrame, + Http2Frame, + ResetStreamFrame, +) +from dubbo.remoting.aio.http2.headers import Http2Headers +from dubbo.remoting.aio.http2.registries import Http2ErrorCode, Http2FrameType + + +class Http2Stream: + """ + A "stream" is an independent, bidirectional sequence of frames exchanged between the client and server within an HTTP/2 connection. + see: https://datatracker.ietf.org/doc/html/rfc7540#section-5 + Args: + stream_id: The stream identifier. + listener: The stream listener. + loop: The asyncio event loop. + protocol: The HTTP/2 protocol. + """ + + def __init__( + self, + stream_id: int, + listener: "StreamListener", + loop: asyncio.AbstractEventLoop, + protocol, + ): + from dubbo.remoting.aio.http2.controllers import FrameOrderController + from dubbo.remoting.aio.http2.protocol import Http2Protocol + + self._loop: asyncio.AbstractEventLoop = loop + self._protocol: Http2Protocol = protocol + + # The stream identifier. + self._id = stream_id + + self._listener = listener + + # The frame order controller. + self._frame_order_controller: FrameOrderController = FrameOrderController( + self, self._loop, self._protocol + ) + self._frame_order_controller.start() + + # Whether the headers have been sent. + self._headers_sent = False + # Whether the headers have been received. + self._headers_received = False + + # Indicates whether the frame identified with end_stream was written (and may not have been sent yet). + self._end_stream = False + + # Whether the stream is closed locally or remotely. + self._local_closed = False + self._remote_closed = False + + @property + def id(self) -> int: + return self._id + + def is_headers_sent(self) -> bool: + return self._headers_sent + + def is_local_closed(self) -> bool: + """ + Check if the stream is closed locally. + """ + return self._local_closed + + def close_local(self) -> None: + """ + Close the stream locally. + """ + self._local_closed = True + self._frame_order_controller.stop() + + def is_remote_closed(self) -> bool: + """ + Check if the stream is closed remotely. + """ + return self._remote_closed + + def close_remote(self) -> None: + """ + Close the stream remotely. + """ + self._remote_closed = True + + def _send_available(self): + """ + Check if the stream is available for sending frames. + """ + return not self.is_local_closed() and not self._end_stream + + def send_headers(self, headers: Http2Headers, end_stream: bool = False) -> None: + """ + Send the headers.(thread-unsafe) + Args: + headers: The HTTP/2 headers. + end_stream: Whether to close the stream after sending the data. + """ + if self.is_headers_sent(): + raise StreamException("Headers have been sent.") + elif not self._send_available(): + raise StreamException( + "The stream cannot send a frame because it has been closed." 
+ ) + + headers_frame = HeadersFrame(self.id, headers, end_stream=end_stream) + self._end_stream = end_stream + self._frame_order_controller.write_headers(headers_frame) + + self._headers_sent = True + + def send_data( + self, data: bytes, end_stream: bool = False, last: bool = False + ) -> None: + """ + Send the data.(thread-unsafe) + Args: + data: The data to send. + end_stream: Whether to close the stream after sending the data. + last: Is it the last data frame? + """ + if not self.is_headers_sent(): + raise StreamException("Headers have not been sent.") + elif not self._send_available(): + raise StreamException( + "The stream cannot send a frame because it has been closed." + ) + + data_frame = DataFrame(self.id, data, len(data), end_stream=end_stream) + self._end_stream = end_stream + self._frame_order_controller.write_data(data_frame, last) + + def send_trailers(self, headers: Http2Headers, send_data: bool) -> None: + """ + Send trailers with the given headers. Optionally, indicate if data frames + need to be sent. + + Args: + headers: The HTTP/2 headers to be sent as trailers. + send_data: A flag indicating whether data frames need to be sent. + """ + if not self.is_headers_sent(): + raise StreamException("Headers have not been sent.") + elif not self._send_available(): + raise StreamException( + "The stream cannot send a frame because it has been closed." + ) + + trailers_frame = HeadersFrame(self.id, headers, end_stream=True) + self._end_stream = True + if send_data: + self._frame_order_controller.write_trailers_after_data(trailers_frame) + else: + self._frame_order_controller.write_trailers(trailers_frame) + + def send_reset(self, error_code: Http2ErrorCode) -> None: + """ + Send the reset frame.(thread-unsafe) + Args: + error_code: The error code. + """ + if self.is_local_closed(): + raise StreamException("The stream has been reset.") + + reset_frame = ResetStreamFrame(self.id, error_code) + # It's a special frame, no need to queue, just send it + self._protocol.write(reset_frame, self) + # close the stream locally and remotely + self.close_local() + self.close_remote() + + def receive_frame(self, frame: Http2Frame) -> None: + """ + Receive a frame from the stream. + Args: + frame: The frame to be received. + """ + if self.is_remote_closed(): + # The stream is closed remotely, ignore the frame + return + + if frame.end_stream: + # received end_stream frame, close the stream remotely + self.close_remote() + + frame_type = frame.frame_type + if frame_type == Http2FrameType.HEADERS: + if not self._headers_received: + # HEADERS frame + self._headers_received = True + self._listener.on_headers(frame.headers, frame.end_stream) + else: + # TRAILERS frame + self._listener.on_trailers(frame.headers) + elif frame_type == Http2FrameType.DATA: + self._listener.on_data(frame.data, frame.end_stream) + elif frame_type == Http2FrameType.RST_STREAM: + self._listener.on_reset(frame.error_code) + self.close_local() + + +class StreamListener: + """ + Http2StreamListener is a base class for handling events in an HTTP/2 stream. + + This class provides a set of callback methods that are called when specific + events occur on the stream, such as receiving headers, receiving data, or + resetting the stream. To use this class, create a subclass and implement the + callback methods for the events you want to handle. + """ + + def __init__(self): + self._stream: Optional[Http2Stream] = None + + def bind(self, stream: Http2Stream) -> None: + """ + Bind the stream to the listener. 
+ Args: + stream: The stream. + """ + self._stream = stream + + def on_headers(self, headers: Http2Headers, end_stream: bool) -> None: + """ + Called when the headers are received. + Args: + headers: The HTTP/2 headers. + end_stream: Whether the stream is closed after receiving the headers. + """ + raise NotImplementedError("on_headers() is not implemented.") + + def on_data(self, data: bytes, end_stream: bool) -> None: + """ + Called when the data is received. + Args: + data: The data. + end_stream: Whether the stream is closed after receiving the data. + """ + raise NotImplementedError("on_data() is not implemented.") + + def on_trailers(self, headers: Http2Headers) -> None: + """ + Called when the trailers are received. + Args: + headers: The HTTP/2 headers. + """ + raise NotImplementedError("on_trailers() is not implemented.") + + def on_reset(self, error_code: Http2ErrorCode) -> None: + """ + Called when the stream is reset. + Args: + error_code: The error code. + """ + raise NotImplementedError("on_reset() is not implemented.") diff --git a/dubbo/remoting/aio/http2/stream_handler.py b/dubbo/remoting/aio/http2/stream_handler.py new file mode 100644 index 0000000..b6e7a3e --- /dev/null +++ b/dubbo/remoting/aio/http2/stream_handler.py @@ -0,0 +1,169 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from concurrent import futures +from typing import Dict, Optional + +from dubbo.logger.logger_factory import loggerFactory +from dubbo.remoting.aio.exceptions import ProtocolException +from dubbo.remoting.aio.http2.frames import Http2Frame +from dubbo.remoting.aio.http2.registries import Http2FrameType +from dubbo.remoting.aio.http2.stream import Http2Stream, StreamListener + +logger = loggerFactory.get_logger(__name__) + + +class StreamMultiplexHandler: + """ + The StreamMultiplexHandler class is responsible for managing the HTTP/2 streams. + """ + + def __init__(self, executor: Optional[futures.ThreadPoolExecutor] = None): + # Import the Http2Protocol class here to avoid circular imports. + from dubbo.remoting.aio.http2.protocol import Http2Protocol + + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._protocol: Optional[Http2Protocol] = None + + # The map of stream_id to stream. + self._streams: Optional[Dict[int, Http2Stream]] = None + + # The executor for handling received frames. + self._executor = executor + + def do_init(self, loop: asyncio.AbstractEventLoop, protocol) -> None: + """ + Initialize the StreamMultiplexHandler.\ + Args: + loop: The asyncio event loop. + protocol: The HTTP/2 protocol. + """ + self._loop = loop + self._protocol = protocol + self._streams = {} + + def put_stream(self, stream_id: int, stream: Http2Stream) -> None: + """ + Put the stream into the stream map. 
+ Args: + stream_id: The stream identifier. + stream: The stream. + """ + self._streams[stream_id] = stream + + def get_stream(self, stream_id: int) -> Optional[Http2Stream]: + """ + Get the stream by stream identifier. + Args: + stream_id: The stream identifier. + Returns: + The stream. + """ + return self._streams.get(stream_id) + + def remove_stream(self, stream_id: int) -> None: + """ + Remove the stream by stream identifier. + Args: + stream_id: The stream identifier. + """ + self._streams.pop(stream_id, None) + + def handle_frame(self, frame: Http2Frame) -> None: + """ + Handle the HTTP/2 frame. + Args: + frame: The HTTP/2 frame. + """ + if stream := self._streams.get(frame.stream_id): + # Handle the frame in the executor. + self._handle_frame_in_executor(stream, frame) + else: + logger.warning( + f"Stream {frame.stream_id} not found. Ignoring frame {frame}" + ) + + def _handle_frame_in_executor(self, stream: Http2Stream, frame: Http2Frame) -> None: + """ + Handle the HTTP/2 frame in the executor. + Args: + frame: The HTTP/2 frame. + """ + self._loop.run_in_executor(self._executor, stream.receive_frame, frame) + + def destroy(self) -> None: + """ + Destroy the StreamMultiplexHandler. + """ + self._streams = None + self._protocol = None + self._loop = None + + +class StreamClientMultiplexHandler(StreamMultiplexHandler): + """ + The StreamClientMultiplexHandler class is responsible for managing the HTTP/2 streams on the client side. + """ + + def create(self, listener: StreamListener) -> Http2Stream: + """ + Create a new stream. + Returns: + The created stream. + """ + future = futures.Future() + self._protocol.get_next_stream_id(future) + try: + # block until the stream_id is created + stream_id = future.result() + self._streams[stream_id] = Http2Stream( + stream_id, listener, self._loop, self._protocol + ) + except Exception as e: + raise ProtocolException("Failed to create stream.") from e + + return self._streams[stream_id] + + +class StreamServerMultiplexHandler(StreamMultiplexHandler): + """ + The StreamServerMultiplexHandler class is responsible for managing the HTTP/2 streams on the server side. + """ + + def register(self, stream_id: int) -> Http2Stream: + """ + Register the stream. + Args: + stream_id: The stream identifier. + Returns: + The created stream. + """ + stream = Http2Stream(stream_id, StreamListener(), self._loop, self._protocol) + self._streams[stream_id] = stream + return stream + + def handle_frame(self, frame: Http2Frame) -> None: + """ + Handle the HTTP/2 frame. + Args: + frame: The HTTP/2 frame. + """ + # Register the stream if the frame is a HEADERS frame. + if frame.frame_type == Http2FrameType.HEADERS: + self.register(frame.stream_id) + + # Handle the frame. + super().handle_frame(frame) diff --git a/dubbo/remoting/aio/http2/utils.py b/dubbo/remoting/aio/http2/utils.py new file mode 100644 index 0000000..8ecb18f --- /dev/null +++ b/dubbo/remoting/aio/http2/utils.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import h2.events as h2_event + +from dubbo.remoting.aio.http2.frames import ( + DataFrame, + HeadersFrame, + Http2Frame, + ResetStreamFrame, + WindowUpdateFrame, +) +from dubbo.remoting.aio.http2.headers import Http2Headers +from dubbo.remoting.aio.http2.registries import Http2ErrorCode + + +class Http2EventUtils: + """ + A utility class for converting H2 events to HTTP/2 frames. + """ + + @staticmethod + def convert_to_frame(event: h2_event.Event) -> Optional[Http2Frame]: + """ + Convert a h2.events.Event to HTTP/2 Frame. + Args: + event: The H2 event to convert. + Returns: + The converted HTTP/2 Frame. If the event is not supported, return None. + """ + if isinstance( + event, + ( + h2_event.RequestReceived, + h2_event.ResponseReceived, + h2_event.TrailersReceived, + ), + ): + # HEADERS frame. + return HeadersFrame( + event.stream_id, + Http2Headers.from_list(event.headers), + end_stream=event.stream_ended is not None, + ) + elif isinstance(event, h2_event.DataReceived): + # DATA frame. + return DataFrame( + event.stream_id, + event.data, + event.flow_controlled_length, + end_stream=event.stream_ended is not None, + ) + elif isinstance(event, h2_event.StreamReset): + # RST_STREAM frame. + return ResetStreamFrame( + event.stream_id, Http2ErrorCode.get(event.error_code) + ) + elif isinstance(event, h2_event.WindowUpdated): + # WINDOW_UPDATE frame. + return WindowUpdateFrame(event.stream_id, event.delta) + else: + return None diff --git a/dubbo/remoting/aio/loop.py b/dubbo/remoting/aio/loop.py deleted file mode 100644 index 503432e..0000000 --- a/dubbo/remoting/aio/loop.py +++ /dev/null @@ -1,150 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import threading -from typing import Optional, Tuple - -from dubbo.logger.logger_factory import loggerFactory - -logger = loggerFactory.get_logger(__name__) - - -def start_loop(running_loop: asyncio.AbstractEventLoop) -> None: - """ - Start the running_loop. - Args: - running_loop: The running_loop to start. - """ - asyncio.set_event_loop(running_loop) - running_loop.run_forever() - - -async def _stop_loop( - running_loop: Optional[asyncio.AbstractEventLoop] = None, - signal: Optional[threading.Event] = None, -) -> None: - """ - Real function to stop the running_loop. - Args: - running_loop: The running_loop to stop. 
If None, the current running_loop will be stopped. - signal: The future to set the result. - """ - running_loop = running_loop or asyncio.get_running_loop() - # Cancel all tasks - tasks = [ - task for task in asyncio.all_tasks(running_loop) if task is not asyncio.current_task() - ] - for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - # Stop the event running_loop - running_loop.stop() - if signal: - # Set the result of the future - signal.set() - - -def stop_loop(running_loop: Optional[asyncio.AbstractEventLoop] = None, wait: bool = False): - """ - Stop the running_loop. It will cancel all tasks and stop the running_loop.(thread-safe) - Args: - running_loop: The running_loop to stop. If None, the current running_loop will be stopped. - wait: Whether to wait for the running_loop to stop. - """ - running_loop = running_loop or asyncio.get_running_loop() - # Create a future to wait for the running_loop to stop - signal = threading.Event() - # Call the asynchronous function to stop the running_loop - asyncio.run_coroutine_threadsafe(_stop_loop(signal=signal), running_loop) - if wait: - # Wait for the running_loop to stop - signal.wait() - - -def start_loop_in_thread( - thread_name: str, running_loop: Optional[asyncio.AbstractEventLoop] = None -) -> Tuple[asyncio.AbstractEventLoop, threading.Thread]: - """ - start the asyncio event running_loop in a separate thread. - - Args: - thread_name: The name of the thread to run the event running_loop in. - running_loop: The event running_loop to run in the thread. If None, a new event running_loop will be created. - - Returns: - A tuple containing the new event running_loop and the thread it is running in. - """ - new_loop = running_loop or asyncio.new_event_loop() - # Start the running_loop in a new thread - thread = threading.Thread( - target=start_loop, args=(new_loop,), name=thread_name, daemon=True - ) - # Start the thread - thread.start() - return new_loop, thread - - -def stop_loop_in_thread( - running_loop: asyncio.AbstractEventLoop, thread: threading.Thread, wait: bool = False -) -> None: - """ - Stop the event running_loop running in a separate thread and close the thread. - - Args: - running_loop: The event running_loop to stop. - thread: The thread running the event running_loop. - wait: Whether to wait for all tasks to be cancelled and the running_loop to stop. - """ - stop_loop(running_loop, wait=wait) - # Wait for the thread to join - if wait: - print("等待线程结束") - thread.join() - - -def _try_use_uvloop() -> None: - """ - Use uvloop instead of the default asyncio running_loop. - """ - import asyncio - import os - - # Check if the operating system. - if os.name == "nt": - # Windows is not supported. - logger.warning( - "Unable to use uvloop, because it is not supported on your operating system." - ) - return - - # Try import uvloop. - try: - import uvloop - except ImportError: - # uvloop is not available. - logger.warning( - "Unable to use uvloop, because it is not installed. " - "You can install it by running `pip install uvloop`." - ) - return - - # Use uvloop instead of the default asyncio running_loop. - if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - - -# Call the function to try to use uvloop. 
-_try_use_uvloop() diff --git a/dubbo/remoting/transporter.py b/dubbo/remoting/transporter.py index ff68bf4..f56dc5f 100644 --- a/dubbo/remoting/transporter.py +++ b/dubbo/remoting/transporter.py @@ -20,30 +20,18 @@ class Client: def __init__(self, url: URL): self._url = url - # flag to indicate whether the client is opened - self._opened = False - # flag to indicate whether the client is connected - self._connected = False - # flag to indicate whether the client is closed - self._closed = False - @property - def opened(self): - return self._opened - - @property - def connected(self): - return self._connected - - @property - def closed(self): - return self._closed + def is_connected(self) -> bool: + """ + Check if the client is connected. + """ + raise NotImplementedError("is_connected() is not implemented.") - def open(self): + def is_closed(self) -> bool: """ - Open the client. + Check if the client is closed. """ - raise NotImplementedError("open() is not implemented.") + raise NotImplementedError("is_closed() is not implemented.") def connect(self): """ @@ -51,6 +39,12 @@ def connect(self): """ raise NotImplementedError("connect() is not implemented.") + def reconnect(self): + """ + Reconnect to the server. + """ + raise NotImplementedError("reconnect() is not implemented.") + def close(self): """ Close the client. diff --git a/dubbo/serialization.py b/dubbo/serialization.py index 3d92f27..0a5baa5 100644 --- a/dubbo/serialization.py +++ b/dubbo/serialization.py @@ -13,71 +13,75 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional -from dubbo.constants import common_constants +from dubbo.constants.type_constants import DeserializingFunction, SerializingFunction from dubbo.logger.logger_factory import loggerFactory -from dubbo.url import URL logger = loggerFactory.get_logger(__name__) -def serialize(method: str, url: URL, *args, **kwargs) -> bytes: +class Serialization: """ - Serialize the given data + Serialization class Args: - method(str): The method to serialize - url(URL): URL - *args: Variable length argument list - **kwargs: Arbitrary keyword arguments - Returns: - bytes: The serialized data - Exception: If the serialization fails + serializing_function(SerializingFunction): The serialization function + deserializing_function(DeserializingFunction): The deserialization function """ - # get the serializer - method_dict = url.get_attribute(method) or {} - serializer = method_dict.get(common_constants.SERIALIZATION) - # serialize the data - if serializer: - try: - return serializer(*args, **kwargs) - except Exception as e: - logger.exception( - "Serialization send error, please check the incoming serialization function" - ) - raise e - else: - # check if the data is bytes -> args[0] - if isinstance(args[0], bytes): - return args[0] - else: - err_msg = "The args[0] is not bytes, you should pass parameters of type bytes, or set the serialization function" - logger.error(err_msg) - raise ValueError(err_msg) + def __init__( + self, + serializing_function: Optional[SerializingFunction] = None, + deserializing_function: Optional[DeserializingFunction] = None, + ): + self.serializing_function = serializing_function + self.deserializing_function = deserializing_function -def deserialize(method: str, url: URL, data: bytes) -> Any: - """ - Deserialize the given data - Args: - method(str): The method to 
deserialize - url(URL): URL - data(bytes): The data to deserialize - Returns: - Any: The deserialized data - Exception: If the deserialization fails - """ - # get the deserializer - method_dict = url.get_attribute(method) or {} - deserializer = method_dict.get(common_constants.DESERIALIZATION) - # deserialize the data - if not deserializer: - return data - else: - try: - return deserializer(data) - except Exception as e: - logger.exception( - "Deserialization send error, please check the incoming deserialization function" - ) - raise e + def serialize(self, *args, **kwargs) -> bytes: + """ + Serialize the given data + Args: + *args: Variable length argument list + **kwargs: Arbitrary keyword arguments + Returns: + bytes: The serialized data + Exception: If the serialization fails + """ + # serialize the data + if self.serializing_function: + try: + return self.serializing_function(*args, **kwargs) + except Exception as e: + logger.exception( + "Serialization send error, please check the incoming serialization function" + ) + raise e + else: + # check if the data is bytes -> args[0] + if isinstance(args[0], bytes): + return args[0] + else: + err_msg = "The args[0] is not bytes, you should pass parameters of type bytes, or set the serialization function" + logger.error(err_msg) + raise ValueError(err_msg) + + def deserialize(self, data: bytes) -> Any: + """ + Deserialize the given data + Args: + data(bytes): The data to deserialize + Returns: + Any: The deserialized data + Exception: If the deserialization fails + """ + # deserialize the data + if not self.deserializing_function: + return data + else: + try: + return self.deserializing_function(data) + except Exception as e: + logger.exception( + "Deserialization send error, please check the incoming deserialization function" + ) + raise e diff --git a/dubbo/url.py b/dubbo/url.py index 0072164..2178457 100644 --- a/dubbo/url.py +++ b/dubbo/url.py @@ -38,11 +38,23 @@ class URL: - registry://192.168.1.7:9090/org.apache.dubbo.service1?param1=value1¶m2=value2 """ + __slots__ = [ + "_scheme", + "_host", + "_port", + "_location", + "_username", + "_password", + "_path", + "_parameters", + "_attributes", + ] + def __init__( self, scheme: str, host: str, - port: int = 0, + port: Optional[int] = None, username: str = "", password: str = "", path: str = "", @@ -53,7 +65,7 @@ def __init__( self._host = host self._port = port # location -> host:port - self._location = f"{host}:{port}" if port > 0 else host + self._location = f"{host}:{port}" if port else host self._username = username self._password = password self._path = path @@ -112,7 +124,7 @@ def host(self, host: str) -> None: self._location = f"{host}:{self.port}" if self.port else host @property - def port(self) -> int: + def port(self) -> Optional[int]: """ Gets the port of the URL. @@ -129,7 +141,7 @@ def port(self, port: int) -> None: Args: port (int): The port to set. """ - self._port = max(port, 0) + port = port if port > 0 else None self._location = f"{self.host}:{port}" if port else self.host @property @@ -192,26 +204,6 @@ def path(self, path: str) -> None: """ self._path = path - @property - def parameters(self) -> Dict[str, str]: - """ - Gets the query parameters of the URL. - - Returns: - Dict[str, str]: The query parameters of the URL. - """ - return self._parameters - - @parameters.setter - def parameters(self, parameters: Dict[str, str]) -> None: - """ - Sets the query parameters of the URL. - - Args: - parameters (Dict[str, str]): The query parameters to set. 
- """ - self._parameters = parameters - def get_parameter(self, key: str) -> Optional[str]: """ Gets a query parameter from the URL. @@ -243,25 +235,6 @@ def attributes(self): """ return self._attributes - def add_attribute(self, key: str, value: Any) -> None: - """ - ADDs an attribute to the URL. - Args: - key (str): The attribute name. - value (Any): The attribute value. - """ - self._attributes[key] = value - - def get_attribute(self, key: str) -> Optional[Any]: - """ - Gets an attribute from the URL. - Args: - key (str): The attribute name. - Returns: - Any: The attribute value. If the attribute does not exist, returns None. - """ - return self._attributes.get(key, None) - def build_string(self, encode: bool = False) -> str: """ Generates the URL string based on the current components. @@ -287,13 +260,29 @@ def build_string(self, encode: bool = False) -> str: if self.path: url += f"{self.path}" # Set params - if self.parameters: - url += "?" + "&".join([f"{k}={v}" for k, v in self.parameters.items()]) + if self._parameters: + url += "?" + "&".join([f"{k}={v}" for k, v in self._parameters.items()]) # If the URL needs to be encoded, encode it if encode: url = parse.quote(url) return url + def clone_without_attributes(self) -> "URL": + """ + Clones the URL object without the attributes. + Returns: + URL: The cloned URL object. + """ + return URL( + self.scheme, + self.host, + self.port, + self.username, + self.password, + self.path, + self._parameters.copy(), + ) + def clone(self) -> "URL": """ Clones the URL object. Ignores the attributes. @@ -308,7 +297,8 @@ def clone(self) -> "URL": self.username, self.password, self.path, - copy.deepcopy(self.parameters), + self._parameters.copy(), + copy.deepcopy(self._attributes), ) def __str__(self) -> str: @@ -346,7 +336,7 @@ def value_of(cls, url: str, encoded: bool = False) -> "URL": protocol = parsed_url.scheme host = parsed_url.hostname or "" - port = parsed_url.port or 0 + port = parsed_url.port or None username = parsed_url.username or "" password = parsed_url.password or "" parameters = {k: v[0] for k, v in parse.parse_qs(parsed_url.query).items()} diff --git a/tests/common/tets_url.py b/tests/common/tets_url.py index fa4c72d..912c939 100644 --- a/tests/common/tets_url.py +++ b/tests/common/tets_url.py @@ -26,7 +26,7 @@ def test_str_to_url(self): ) self.assertEqual("http", url_0.scheme) self.assertEqual("www.facebook.com", url_0.host) - self.assertEqual(0, url_0.port) + self.assertEqual(None, url_0.port) self.assertEqual("friends", url_0.path) self.assertEqual("value1", url_0.get_parameter("param1")) self.assertEqual("value2", url_0.get_parameter("param2")) @@ -50,7 +50,7 @@ def test_str_to_url(self): ) self.assertEqual("http", url_3.scheme) self.assertEqual("www.facebook.com", url_3.host) - self.assertEqual(0, url_3.port) + self.assertEqual(None, url_3.port) self.assertEqual("friends", url_3.path) self.assertEqual("value1", url_3.get_parameter("param1")) self.assertEqual("value2", url_3.get_parameter("param2")) From 7608afe3f1887b725806a68334942c40473ce0a2 Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 4 Aug 2024 14:05:38 +0800 Subject: [PATCH 29/32] feat: Refactored and refined rpc calling capabilities --- dubbo/__init__.py | 2 - dubbo/_dubbo.py | 176 ------ dubbo/{client => }/client.py | 63 ++- .../protocol.py => common/__init__.py} | 28 +- .../compression.py => common/classes.py} | 36 +- .../constants.py} | 58 +- dubbo/common/deliverers.py | 314 +++++++++++ dubbo/{ => common}/node.py | 34 +- .../type_constants.py => common/types.py} | 3 
+ dubbo/common/url.py | 325 ++++++++++++ dubbo/common/utils.py | 129 +++++ dubbo/compression/__init__.py | 22 + dubbo/compression/_interfaces.py | 69 +++ dubbo/compression/bzip2s.py | 56 ++ .../gzips.py} | 40 +- dubbo/compression/identities.py | 57 ++ dubbo/config/__init__.py | 2 +- dubbo/config/logger_config.py | 15 +- dubbo/config/reference_config.py | 8 +- dubbo/config/service_config.py | 71 +++ dubbo/extension/__init__.py | 1 + dubbo/extension/extension_loader.py | 111 ++-- .../extension/{registry.py => registries.py} | 73 +-- dubbo/{constants => loadbalance}/__init__.py | 2 + dubbo/loadbalance/_interfaces.py | 78 +++ dubbo/logger/__init__.py | 8 +- dubbo/logger/_interfaces.py | 204 +++++++ .../constants.py} | 53 +- dubbo/logger/logger.py | 175 ------ dubbo/logger/logger_factory.py | 136 ++--- dubbo/logger/logging/__init__.py | 2 + dubbo/logger/logging/formatter.py | 3 + dubbo/logger/logging/logger.py | 35 +- dubbo/logger/logging/logger_adapter.py | 99 ++-- dubbo/protocol/__init__.py | 4 + dubbo/protocol/_interfaces.py | 121 +++++ dubbo/protocol/invocation.py | 61 +-- dubbo/protocol/invoker.py | 35 -- dubbo/protocol/result.py | 67 --- dubbo/protocol/triple/call/__init__.py | 20 + dubbo/protocol/triple/call/_interfaces.py | 143 +++++ dubbo/protocol/triple/call/client_call.py | 178 +++++++ dubbo/protocol/triple/call/server_call.py | 268 ++++++++++ dubbo/protocol/triple/client/calls.py | 156 ------ .../protocol/triple/client/stream_listener.py | 108 ---- .../triple/{tri_codec.py => coders.py} | 146 +++-- .../triple/{tri_status.py => constants.py} | 54 +- .../{tri_constants.py => exceptions.py} | 36 +- dubbo/protocol/triple/invoker.py | 215 ++++++++ dubbo/protocol/triple/metadata.py | 95 ++++ dubbo/protocol/triple/protocol.py | 106 ++++ .../triple/{tri_results.py => results.py} | 75 ++- dubbo/protocol/triple/status.py | 152 ++++++ .../triple/stream}/__init__.py | 4 + dubbo/protocol/triple/stream/_interfaces.py | 167 ++++++ dubbo/protocol/triple/stream/client_stream.py | 312 +++++++++++ dubbo/protocol/triple/stream/server_stream.py | 325 ++++++++++++ dubbo/protocol/triple/tri_invoker.py | 140 ----- dubbo/protocol/triple/tri_protocol.py | 61 --- .../triple/client => proxy}/__init__.py | 4 + dubbo/proxy/_interfaces.py | 61 +++ dubbo/{callable.py => proxy/callables.py} | 42 +- dubbo/proxy/handlers.py | 136 +++++ dubbo/{compressor => registry}/__init__.py | 2 + dubbo/registry/_interfaces.py | 82 +++ .../registry/zookeeper/__init__.py | 15 +- dubbo/registry/zookeeper/_interfaces.py | 251 +++++++++ dubbo/registry/zookeeper/kazoo_transport.py | 427 +++++++++++++++ dubbo/registry/zookeeper/zk_registry.py | 88 +++ dubbo/remoting/__init__.py | 4 + .../{transporter.py => _interfaces.py} | 68 ++- dubbo/remoting/aio/aio_transporter.py | 166 ++++-- dubbo/remoting/aio/constants.py | 21 + dubbo/remoting/aio/event_loop.py | 9 +- dubbo/remoting/aio/exceptions.py | 10 +- dubbo/remoting/aio/http2/controllers.py | 500 ++++++++++-------- dubbo/remoting/aio/http2/frames.py | 37 +- dubbo/remoting/aio/http2/headers.py | 112 ++-- dubbo/remoting/aio/http2/protocol.py | 122 +++-- dubbo/remoting/aio/http2/registries.py | 3 + dubbo/remoting/aio/http2/stream.py | 376 +++++++------ dubbo/remoting/aio/http2/stream_handler.py | 98 ++-- dubbo/remoting/aio/http2/utils.py | 10 +- dubbo/serialization.py | 87 --- dubbo/serialization/__init__.py | 30 ++ dubbo/serialization/_interfaces.py | 91 ++++ dubbo/serialization/custom_serializers.py | 85 +++ dubbo/serialization/direct_serializers.py | 58 ++ .../{config/consumer_config.py => 
server.py} | 27 +- dubbo/url.py | 347 ------------ requirements.txt | 3 +- tests/common/tets_url.py | 24 +- tests/logger/__init__.py | 15 - tests/logger/test_logger_factory.py | 49 -- tests/logger/test_logging_logger.py | 50 -- 95 files changed, 6383 insertions(+), 2664 deletions(-) delete mode 100644 dubbo/_dubbo.py rename dubbo/{client => }/client.py (63%) rename dubbo/{protocol/protocol.py => common/__init__.py} (65%) rename dubbo/{compressor/compression.py => common/classes.py} (57%) rename dubbo/{constants/common_constants.py => common/constants.py} (53%) create mode 100644 dubbo/common/deliverers.py rename dubbo/{ => common}/node.py (64%) rename dubbo/{constants/type_constants.py => common/types.py} (93%) create mode 100644 dubbo/common/url.py create mode 100644 dubbo/common/utils.py create mode 100644 dubbo/compression/__init__.py create mode 100644 dubbo/compression/_interfaces.py create mode 100644 dubbo/compression/bzip2s.py rename dubbo/{compressor/gzip_compression.py => compression/gzips.py} (58%) create mode 100644 dubbo/compression/identities.py create mode 100644 dubbo/config/service_config.py rename dubbo/extension/{registry.py => registries.py} (54%) rename dubbo/{constants => loadbalance}/__init__.py (92%) create mode 100644 dubbo/loadbalance/_interfaces.py create mode 100644 dubbo/logger/_interfaces.py rename dubbo/{constants/logger_constants.py => logger/constants.py} (64%) delete mode 100644 dubbo/logger/logger.py create mode 100644 dubbo/protocol/_interfaces.py delete mode 100644 dubbo/protocol/invoker.py delete mode 100644 dubbo/protocol/result.py create mode 100644 dubbo/protocol/triple/call/__init__.py create mode 100644 dubbo/protocol/triple/call/_interfaces.py create mode 100644 dubbo/protocol/triple/call/client_call.py create mode 100644 dubbo/protocol/triple/call/server_call.py delete mode 100644 dubbo/protocol/triple/client/calls.py delete mode 100644 dubbo/protocol/triple/client/stream_listener.py rename dubbo/protocol/triple/{tri_codec.py => coders.py} (56%) rename dubbo/protocol/triple/{tri_status.py => constants.py} (75%) rename dubbo/protocol/triple/{tri_constants.py => exceptions.py} (57%) create mode 100644 dubbo/protocol/triple/invoker.py create mode 100644 dubbo/protocol/triple/metadata.py create mode 100644 dubbo/protocol/triple/protocol.py rename dubbo/protocol/triple/{tri_results.py => results.py} (51%) create mode 100644 dubbo/protocol/triple/status.py rename dubbo/{client => protocol/triple/stream}/__init__.py (88%) create mode 100644 dubbo/protocol/triple/stream/_interfaces.py create mode 100644 dubbo/protocol/triple/stream/client_stream.py create mode 100644 dubbo/protocol/triple/stream/server_stream.py delete mode 100644 dubbo/protocol/triple/tri_invoker.py delete mode 100644 dubbo/protocol/triple/tri_protocol.py rename dubbo/{protocol/triple/client => proxy}/__init__.py (87%) create mode 100644 dubbo/proxy/_interfaces.py rename dubbo/{callable.py => proxy/callables.py} (57%) create mode 100644 dubbo/proxy/handlers.py rename dubbo/{compressor => registry}/__init__.py (93%) create mode 100644 dubbo/registry/_interfaces.py rename tests/test_dubbo.py => dubbo/registry/zookeeper/__init__.py (85%) create mode 100644 dubbo/registry/zookeeper/_interfaces.py create mode 100644 dubbo/registry/zookeeper/kazoo_transport.py create mode 100644 dubbo/registry/zookeeper/zk_registry.py rename dubbo/remoting/{transporter.py => _interfaces.py} (55%) create mode 100644 dubbo/remoting/aio/constants.py delete mode 100644 dubbo/serialization.py create mode 100644 
dubbo/serialization/__init__.py create mode 100644 dubbo/serialization/_interfaces.py create mode 100644 dubbo/serialization/custom_serializers.py create mode 100644 dubbo/serialization/direct_serializers.py rename dubbo/{config/consumer_config.py => server.py} (67%) delete mode 100644 dubbo/url.py delete mode 100644 tests/logger/__init__.py delete mode 100644 tests/logger/test_logger_factory.py delete mode 100644 tests/logger/test_logging_logger.py diff --git a/dubbo/__init__.py b/dubbo/__init__.py index a5a99ea..bcba37a 100644 --- a/dubbo/__init__.py +++ b/dubbo/__init__.py @@ -13,5 +13,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from ._dubbo import Dubbo diff --git a/dubbo/_dubbo.py b/dubbo/_dubbo.py deleted file mode 100644 index fece509..0000000 --- a/dubbo/_dubbo.py +++ /dev/null @@ -1,176 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import threading -from typing import Dict, List - -from dubbo.config import ApplicationConfig, ConsumerConfig, LoggerConfig, ProtocolConfig -from dubbo.logger.logger_factory import loggerFactory - -logger = loggerFactory.get_logger(__name__) - - -class Dubbo: - - # class variable - _instance = None - _ins_lock = threading.Lock() - - # instance variable - # common - _application: ApplicationConfig - _protocols: Dict[str, ProtocolConfig] - _logger: LoggerConfig - # consumer - _consumer: ConsumerConfig - # provider - # .... - - __slots__ = ["_application", "_protocols", "_logger", "_consumer"] - - def __new__(cls, *args, **kwargs): - # dubbo object is singleton - if cls._instance is None: - with cls._ins_lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self): - # common - self._application = ApplicationConfig.default_config() - self._protocols = {} - self._logger = LoggerConfig.default_config() - # consumer - self._consumer = ConsumerConfig.default_config() - # provider - # TODO add provider config - - # @overload - # def new_client( - # self, reference: str, consumer: Optional[ConsumerConfig] = None - # ) -> Client: ... - # - # @overload - # def new_client( - # self, - # reference: ReferenceConfig, - # consumer: Optional[ConsumerConfig] = None, - # ) -> Client: ... 
- # - # def new_client( - # self, - # reference: Union[str, ReferenceConfig], - # consumer: Optional[ConsumerConfig] = None, - # ) -> Client: - # """ - # Create a new client - # Args: - # reference: reference value - # consumer: consumer config - # Returns: - # Client: A new instance of Client - # """ - # if isinstance(reference, str): - # reference = ReferenceConfig() - # elif isinstance(reference, ReferenceConfig): - # reference = reference - # else: - # raise TypeError( - # "reference must be a string or an instance of ReferenceConfig" - # ) - # consumer_config = consumer or self._consumer.clone() - # return Client(reference, consumer_config) - - def new_server(self): - """ - Create a new server - """ - pass - - def _init(self): - pass - - def start(self): - pass - - def destroy(self): - pass - - def with_application(self, application_config: ApplicationConfig) -> "Dubbo": - """ - Set application config - Args: - application_config: new application config - Returns: - self: Dubbo instance - """ - if application_config is None or not isinstance( - application_config, ApplicationConfig - ): - raise ValueError("application must be an instance of ApplicationConfig") - self._application = application_config - return self - - def with_protocol(self, protocol_config: ProtocolConfig) -> "Dubbo": - """ - Set protocol config - Args: - protocol_config: new protocol config - Returns: - self: Dubbo instance - """ - if protocol_config is None or not isinstance(protocol_config, ProtocolConfig): - raise ValueError("protocol must be an instance of ProtocolConfig") - self._protocols[protocol_config.name] = protocol_config - return self - - def with_protocols(self, protocol_configs: List[ProtocolConfig]) -> "Dubbo": - """ - Set protocol config - Args: - protocol_configs: new protocol configs - Returns: - self: Dubbo instance - """ - for protocol_config in protocol_configs: - self.with_protocol(protocol_config) - return self - - def with_logger(self, logger_config: LoggerConfig) -> "Dubbo": - """ - Set logger config - Args: - logger_config: new logger config - Returns: - self: Dubbo instance - """ - if logger_config is None or not isinstance(logger_config, LoggerConfig): - raise ValueError("logger must be an instance of LoggerConfig") - self._logger = logger_config - return self - - def with_consumer(self, consumer_config: ConsumerConfig) -> "Dubbo": - """ - Set consumer config - Args: - consumer_config: new consumer config - Returns: - self: Dubbo instance - """ - if consumer_config is None or not isinstance(consumer_config, ConsumerConfig): - raise ValueError("consumer must be an instance of ConsumerConfig") - self._consumer = consumer_config - return self diff --git a/dubbo/client/client.py b/dubbo/client.py similarity index 63% rename from dubbo/client/client.py rename to dubbo/client.py index 6ab37c3..f6e6868 100644 --- a/dubbo/client/client.py +++ b/dubbo/client.py @@ -13,27 +13,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
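For orientation, here is a minimal usage sketch of the refactored Client in this hunk (dubbo/client.py). It is illustrative only: the "greet" method name and the str-based serializer/deserializer lambdas are invented for the example, and construction of the ReferenceConfig is left to the caller because it is not shown in this hunk.

from dubbo.client import Client
from dubbo.config import ReferenceConfig


def make_greet_callable(reference: ReferenceConfig):
    # Wrap an existing service reference and expose one unary method
    # as a plain callable.
    client = Client(reference)
    return client.unary(
        method_name="greet",                                # invented method name
        request_serializer=lambda s: s.encode("utf-8"),     # Any -> bytes
        response_deserializer=lambda b: b.decode("utf-8"),  # bytes -> Any
    )

# reply = make_greet_callable(reference)("hello")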
-from typing import Optional -from dubbo.callable import RpcCallable -from dubbo.config import ConsumerConfig, ReferenceConfig -from dubbo.constants import common_constants -from dubbo.constants.type_constants import DeserializingFunction, SerializingFunction -from dubbo.logger.logger_factory import loggerFactory -from dubbo.serialization import Serialization +from typing import Optional -logger = loggerFactory.get_logger(__name__) +from dubbo.common import constants as common_constants +from dubbo.common.types import DeserializingFunction, SerializingFunction +from dubbo.config import ReferenceConfig +from dubbo.proxy import RpcCallable +from dubbo.proxy.callables import MultipleRpcCallable class Client: - __slots__ = ["_consumer", "_reference"] + __slots__ = ["_reference"] - def __init__( - self, reference: ReferenceConfig, consumer: Optional[ConsumerConfig] = None - ): + def __init__(self, reference: ReferenceConfig): self._reference = reference - self._consumer = consumer or ConsumerConfig.default_config() def unary( self, @@ -42,7 +37,7 @@ def unary( response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: return self._callable( - common_constants.CALL_UNARY, + common_constants.UNARY_CALL_VALUE, method_name, request_serializer, response_deserializer, @@ -55,7 +50,7 @@ def client_stream( response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: return self._callable( - common_constants.CALL_CLIENT_STREAM, + common_constants.CLIENT_STREAM_CALL_VALUE, method_name, request_serializer, response_deserializer, @@ -68,7 +63,7 @@ def server_stream( response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: return self._callable( - common_constants.CALL_SERVER_STREAM, + common_constants.SERVER_STREAM_CALL_VALUE, method_name, request_serializer, response_deserializer, @@ -81,7 +76,7 @@ def bidi_stream( response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: return self._callable( - common_constants.CALL_BIDI_STREAM, + common_constants.BI_STREAM_CALL_VALUE, method_name, request_serializer, response_deserializer, @@ -95,26 +90,30 @@ def _callable( response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: """ - Generate a callable for the given method - Args: - call_type: call type - method_name: method name - request_serializer: request serializer, args: Any, return: bytes - response_deserializer: response deserializer, args: bytes, return: Any - Returns: - RpcCallable: The callable object + Generate a proxy for the given method + :param call_type: The call type. + :type call_type: str + :param method_name: The method name. + :type method_name: str + :param request_serializer: The request serializer. + :type request_serializer: Optional[SerializingFunction] + :param response_deserializer: The response deserializer. + :type response_deserializer: Optional[DeserializingFunction] + :return: The proxy. 
+ :rtype: RpcCallable """ # get invoker invoker = self._reference.get_invoker() url = invoker.get_url() # clone url - url = url.clone_without_attributes() - url.add_parameter(common_constants.METHOD_KEY, method_name) - url.add_parameter(common_constants.CALL_KEY, call_type) + url = url.copy() + url.parameters[common_constants.METHOD_KEY] = method_name + url.parameters[common_constants.CALL_KEY] = call_type - serialization = Serialization(request_serializer, response_deserializer) - url.attributes[common_constants.SERIALIZATION] = serialization + # set serializer and deserializer + url.attributes[common_constants.SERIALIZER_KEY] = request_serializer + url.attributes[common_constants.DESERIALIZER_KEY] = response_deserializer - # create callable - return RpcCallable(invoker, url) + # create proxy + return MultipleRpcCallable(invoker, url) diff --git a/dubbo/protocol/protocol.py b/dubbo/common/__init__.py similarity index 65% rename from dubbo/protocol/protocol.py rename to dubbo/common/__init__.py index 7de46f1..a860593 100644 --- a/dubbo/protocol/protocol.py +++ b/dubbo/common/__init__.py @@ -13,18 +13,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dubbo.protocol.invoker import Invoker -from dubbo.url import URL +from .classes import SingletonBase +from .deliverers import MultiMessageDeliverer, SingleMessageDeliverer +from .node import Node +from .types import DeserializingFunction, SerializingFunction +from .url import URL, create_url -class Protocol: - - def refer(self, url: URL) -> Invoker: - """ - Refer a remote service. - Args: - url (URL): The URL of the remote service. - Returns: - Invoker: The invoker of the remote service. - """ - raise NotImplementedError("refer() is not implemented.") +__all__ = [ + "SingleMessageDeliverer", + "MultiMessageDeliverer", + "URL", + "create_url", + "Node", + "SingletonBase", + "DeserializingFunction", + "SerializingFunction", +] diff --git a/dubbo/compressor/compression.py b/dubbo/common/classes.py similarity index 57% rename from dubbo/compressor/compression.py rename to dubbo/common/classes.py index 342225b..b27c7b9 100644 --- a/dubbo/compressor/compression.py +++ b/dubbo/common/classes.py @@ -14,28 +14,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading -class Compression: +__all__ = ["SingletonBase"] + + +class SingletonBase: """ - Compression interface + Singleton base class. This class ensures that only one instance of a derived class exists. + + This implementation is thread-safe. """ - def compress(self, data: bytes) -> bytes: - """ - Compress the data - Args: - data (bytes): Data to compress - Returns: - bytes: Compressed data - """ - raise NotImplementedError("compress() is not implemented.") + _instance = None + _instance_lock = threading.Lock() - def decompress(self, data: bytes) -> bytes: + def __new__(cls, *args, **kwargs): """ - Decompress the data - Args: - data (bytes): Data to decompress - Returns: - bytes: Decompressed data + Create a new instance of the class if it does not exist. 
""" - raise NotImplementedError("decompress() is not implemented.") + if cls._instance is None: + with cls._instance_lock: + # double check + if cls._instance is None: + cls._instance = super(SingletonBase, cls).__new__(cls) + return cls._instance diff --git a/dubbo/constants/common_constants.py b/dubbo/common/constants.py similarity index 53% rename from dubbo/constants/common_constants.py rename to dubbo/common/constants.py index cff24c9..33e4f9f 100644 --- a/dubbo/constants/common_constants.py +++ b/dubbo/common/constants.py @@ -14,35 +14,49 @@ # See the License for the specific language governing permissions and # limitations under the License. +PROTOCOL_KEY = "protocol" +TRIPLE = "triple" +TRIPLE_SHORT = "tri" -TRIPLE = "tri" +SIDE_KEY = "side" +SERVER_VALUE = "server" +CLIENT_VALUE = "client" -LOCALHOST_KEY = "localhost" -LOCALHOST_VALUE = "127.0.0.1" +METHOD_KEY = "method" +SERVICE_KEY = "service" -CALL_KEY = "call" -CALL_UNARY = "unary" -CALL_CLIENT_STREAM = "client-stream" -CALL_SERVER_STREAM = "server-stream" -CALL_BIDI_STREAM = "bidi-stream" -ASYNC_KEY = "async" +SERVICE_HANDLER_KEY = "service-handler" -SERIALIZATION = "serialization" +GROUP_KEY = "group" -COMPRESSION = "compression" +LOCAL_HOST_KEY = "localhost" +LOCAL_HOST_VALUE = "127.0.0.1" +DEFAULT_PORT = 50051 -SERVER_KEY = "server" -METHOD_KEY = "method" +SSL_ENABLED_KEY = "ssl-enabled" + +SERIALIZATION_KEY = "serialization" +SERIALIZER_KEY = "serializer" +DESERIALIZER_KEY = "deserializer" -TRUE_VALUE = "true" -FALSE_VALUE = "false" + +COMPRESSION_KEY = "compression" +COMPRESSOR_KEY = "compressor" +DECOMPRESSOR_KEY = "decompressor" -# Constants about the transporter. TRANSPORTER_KEY = "transporter" -TRANSPORTER_SIDE_KEY = "transporter-side" -TRANSPORTER_SIDE_SERVER = "server" -TRANSPORTER_SIDE_CLIENT = "client" -TRANSPORTER_PROTOCOL_KEY = "protocol" -TRANSPORTER_STREAM_HANDLER_KEY = "stream-handler" -TRANSPORTER_ON_CONN_CLOSE_KEY = "on-conn-close" +TRANSPORTER_DEFAULT_VALUE = "aio" + +TRUE_VALUE = "true" +FALSE_VALUE = "false" + +CALL_KEY = "call" +UNARY_CALL_VALUE = "unary" +CLIENT_STREAM_CALL_VALUE = "client-stream" +SERVER_STREAM_CALL_VALUE = "server-stream" +BI_STREAM_CALL_VALUE = "bi-stream" + +PATH_SEPARATOR = "/" +PROTOCOL_SEPARATOR = "://" +DYNAMIC_KEY = "dynamic" diff --git a/dubbo/common/deliverers.py b/dubbo/common/deliverers.py new file mode 100644 index 0000000..67790ec --- /dev/null +++ b/dubbo/common/deliverers.py @@ -0,0 +1,314 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import enum +import queue +import threading +from typing import Any, Optional + +__all__ = ["MessageDeliverer", "SingleMessageDeliverer", "MultiMessageDeliverer"] + + +class DelivererStatus(enum.Enum): + """ + Enumeration for deliverer status. 
+ + Possible statuses: + - PENDING: The deliverer is pending action. + - COMPLETED: The deliverer has completed the action. + - CANCELLED: The action for the deliverer has been cancelled. + - FINISHED: The deliverer has finished all actions and is in a final state. + """ + + PENDING = 0 + COMPLETED = 1 + CANCELLED = 2 + FINISHED = 3 + + @classmethod + def change_allowed( + cls, current_status: "DelivererStatus", target_status: "DelivererStatus" + ) -> bool: + """ + Check if a transition from `current_status` to `target_status` is allowed. + + :param current_status: The current status of the deliverer. + :type current_status: DelivererStatus + :param target_status: The target status to transition to. + :type target_status: DelivererStatus + :return: A boolean indicating if the transition is allowed. + :rtype: bool + """ + # PENDING -> COMPLETED or CANCELLED + if current_status == cls.PENDING: + return target_status in {cls.COMPLETED, cls.CANCELLED} + + # COMPLETED -> FINISHED or CANCELLED + elif current_status == cls.COMPLETED: + return target_status in {cls.FINISHED, cls.CANCELLED} + + # CANCELLED -> FINISHED + elif current_status == cls.CANCELLED: + return target_status == cls.FINISHED + + # FINISHED is the final state, no further transitions allowed + else: + return False + + +class NoMoreMessageError(RuntimeError): + """ + Exception raised when no more messages are available. + """ + + def __init__(self, message: str = "No more message"): + super().__init__(message) + + +class EmptyMessageError(RuntimeError): + """ + Exception raised when the message is empty. + """ + + def __init__(self, message: str = "Message is empty"): + super().__init__(message) + + +class MessageDeliverer(abc.ABC): + """ + Abstract base class for message deliverers. + """ + + __slots__ = ["_status"] + + def __init__(self): + self._status = DelivererStatus.PENDING + + @abc.abstractmethod + def add(self, message: Any) -> None: + """ + Add a message to the deliverer. + + :param message: The message to be added. + :type message: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def complete(self, message: Any = None) -> None: + """ + Mark the message delivery as complete. + + :param message: The last message (optional). + :type message: Any, optional + """ + raise NotImplementedError() + + @abc.abstractmethod + def cancel(self, exc: Optional[Exception]) -> None: + """ + Cancel the message delivery. + + :param exc: The exception that caused the cancellation. + :type exc: Exception, optional + """ + raise NotImplementedError() + + @abc.abstractmethod + def get(self) -> Any: + """ + Get the next message. + + :return: The next message. + :rtype: Any + :raises NoMoreMessageError: If no more messages are available. + :raises Exception: If the message delivery is cancelled. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_nowait(self) -> Any: + """ + Get the next message without waiting. + + :return: The next message. + :rtype: Any + :raises EmptyMessageError: If the message is empty. + :raises NoMoreMessageError: If no more messages are available. + :raises Exception: If the message delivery is cancelled. + """ + raise NotImplementedError() + + +class SingleMessageDeliverer(MessageDeliverer): + """ + Message deliverer for a single message using a signal-based approach. 
+    """
+
+    __slots__ = ["_condition", "_message"]
+
+    def __init__(self):
+        super().__init__()
+        self._condition = threading.Condition()
+        self._message: Any = None
+
+    def add(self, message: Any) -> None:
+        with self._condition:
+            if self._status is DelivererStatus.PENDING:
+                # Add the message
+                self._message = message
+
+    def complete(self, message: Any = None) -> None:
+        with self._condition:
+            if DelivererStatus.change_allowed(self._status, DelivererStatus.COMPLETED):
+                if message is not None:
+                    self._message = message
+                # update the status
+                self._status = DelivererStatus.COMPLETED
+                self._condition.notify_all()
+
+    def cancel(self, exc: Optional[Exception]) -> None:
+        with self._condition:
+            if DelivererStatus.change_allowed(self._status, DelivererStatus.CANCELLED):
+                # Cancel the delivery
+                self._message = exc or RuntimeError("delivery cancelled.")
+                self._status = DelivererStatus.CANCELLED
+                self._condition.notify_all()
+
+    def get(self) -> Any:
+        with self._condition:
+            if self._status is DelivererStatus.FINISHED:
+                raise NoMoreMessageError("Message already consumed.")
+
+            if self._status is DelivererStatus.PENDING:
+                # If the message is not available, wait
+                self._condition.wait()
+
+            # check the status
+            if self._status is DelivererStatus.CANCELLED:
+                raise self._message
+
+            self._status = DelivererStatus.FINISHED
+            return self._message
+
+    def get_nowait(self) -> Any:
+        with self._condition:
+            if self._status is DelivererStatus.COMPLETED:
+                self._status = DelivererStatus.FINISHED
+                return self._message
+
+            # raise error
+            if self._status is DelivererStatus.FINISHED:
+                raise NoMoreMessageError("Message already consumed.")
+            elif self._status is DelivererStatus.CANCELLED:
+                raise self._message
+            elif self._status is DelivererStatus.PENDING:
+                raise EmptyMessageError("Message is empty")
+
+
+class MultiMessageDeliverer(MessageDeliverer):
+    """
+    Message deliverer supporting multiple messages.
+ """ + + __slots__ = ["_lock", "_counter", "_messages", "_END_SENTINEL"] + + def __init__(self): + super().__init__() + self._lock = threading.Lock() + self._counter = 0 + self._messages: queue.PriorityQueue[Any] = queue.PriorityQueue() + self._END_SENTINEL = object() + + def add(self, message: Any) -> None: + with self._lock: + if self._status is DelivererStatus.PENDING: + # Add the message + self._counter += 1 + self._messages.put_nowait((self._counter, message)) + + def complete(self, message: Any = None) -> None: + with self._lock: + if DelivererStatus.change_allowed(self._status, DelivererStatus.COMPLETED): + if message is not None: + self._counter += 1 + self._messages.put_nowait((self._counter, message)) + + # Add the end sentinel + self._counter += 1 + self._messages.put_nowait((self._counter, self._END_SENTINEL)) + self._status = DelivererStatus.COMPLETED + + def cancel(self, exc: Optional[Exception]) -> None: + with self._lock: + if DelivererStatus.change_allowed(self._status, DelivererStatus.CANCELLED): + # Set the priority to -1 -> make sure it is the first message + self._messages.put_nowait( + (-1, exc or RuntimeError("delivery cancelled.")) + ) + self._status = DelivererStatus.CANCELLED + + def get(self) -> Any: + if self._status is DelivererStatus.FINISHED: + raise NoMoreMessageError("No more message") + + # block until the message is available + priority, message = self._messages.get() + + # check the status + if self._status is DelivererStatus.CANCELLED: + raise message + elif message is self._END_SENTINEL: + self._status = DelivererStatus.FINISHED + raise NoMoreMessageError("No more message") + else: + return message + + def get_nowait(self) -> Any: + try: + if self._status is DelivererStatus.FINISHED: + raise NoMoreMessageError("No more message") + + priority, message = self._messages.get_nowait() + + # check the status + if self._status is DelivererStatus.CANCELLED: + raise message + elif message is self._END_SENTINEL: + self._status = DelivererStatus.FINISHED + raise NoMoreMessageError("No more message") + else: + return message + except queue.Empty: + raise EmptyMessageError("Message is empty") + + def __iter__(self): + return self + + def __next__(self): + """ + Returns the next request from the queue. + + :return: The next message. + :rtype: Any + :raises StopIteration: If no more messages are available. + """ + while True: + try: + return self.get() + except NoMoreMessageError: + raise StopIteration diff --git a/dubbo/node.py b/dubbo/common/node.py similarity index 64% rename from dubbo/node.py rename to dubbo/common/node.py index f63e12b..a5ec339 100644 --- a/dubbo/node.py +++ b/dubbo/common/node.py @@ -13,32 +13,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dubbo.url import URL +import abc -class Node: +from dubbo.common.url import URL + +__all__ = ["Node"] + + +class Node(abc.ABC): """ - Node + Abstract base class for a Node. """ + @abc.abstractmethod def get_url(self) -> URL: """ - Get the url of the node - Returns: - URL: URL of the node + Get the URL of the node. + + :return: The URL of the node. + :rtype: URL + :raises NotImplementedError: If the method is not implemented. 
""" raise NotImplementedError("get_url() is not implemented.") + @abc.abstractmethod def is_available(self) -> bool: """ - Check if the node is available - Returns: - bool: True if the node is available, false otherwise + Check if the node is available. + + :return: True if the node is available, False otherwise. + :rtype: bool + :raises NotImplementedError: If the method is not implemented. """ raise NotImplementedError("is_available() is not implemented.") + @abc.abstractmethod def destroy(self) -> None: """ - Destroy the node + Destroy the node. + + :raises NotImplementedError: If the method is not implemented. """ raise NotImplementedError("destroy() is not implemented.") diff --git a/dubbo/constants/type_constants.py b/dubbo/common/types.py similarity index 93% rename from dubbo/constants/type_constants.py rename to dubbo/common/types.py index bb332be..029b837 100644 --- a/dubbo/constants/type_constants.py +++ b/dubbo/common/types.py @@ -13,7 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Any, Callable +__all__ = ["SerializingFunction", "DeserializingFunction"] + SerializingFunction = Callable[[Any], bytes] DeserializingFunction = Callable[[bytes], Any] diff --git a/dubbo/common/url.py b/dubbo/common/url.py new file mode 100644 index 0000000..581fd84 --- /dev/null +++ b/dubbo/common/url.py @@ -0,0 +1,325 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Any, Dict, Optional +from urllib import parse +from urllib.parse import urlencode, urlunparse + +from dubbo.common.constants import PROTOCOL_SEPARATOR + +__all__ = ["URL", "create_url"] + + +def create_url(url: str, encoded: bool = False) -> "URL": + """ + Creates a URL object from a URL string. + + This function takes a URL string and converts it into a URL object. + If the 'encoded' parameter is set to True, the URL string will be decoded before being converted. + + :param url: The URL string to be converted into a URL object. + :type url: str + :param encoded: Determines if the URL string should be decoded before being converted. Defaults to False. + :type encoded: bool + :return: A URL object. + :rtype: URL + :raises ValueError: If the URL format is invalid. 
+ """ + # If the URL is encoded, decode it + if encoded: + url = parse.unquote(url) + + if PROTOCOL_SEPARATOR not in url: + raise ValueError("Invalid URL format: missing protocol") + + parsed_url = parse.urlparse(url) + + if not parsed_url.scheme: + raise ValueError("Invalid URL format: missing scheme.") + + return URL( + parsed_url.scheme, + parsed_url.hostname or "", + parsed_url.port, + parsed_url.username or "", + parsed_url.password or "", + parsed_url.path.lstrip("/"), + {k: v[0] for k, v in parse.parse_qs(parsed_url.query).items()}, + ) + + +class URL: + """ + URL - Uniform Resource Locator. + """ + + __slots__ = [ + "_scheme", + "_host", + "_port", + "_location", + "_username", + "_password", + "_path", + "_parameters", + "_attributes", + ] + + def __init__( + self, + scheme: str, + host: str, + port: Optional[int] = None, + username: str = "", + password: str = "", + path: str = "", + parameters: Optional[Dict[str, str]] = None, + attributes: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the URL object. + + :param scheme: The scheme of the URL (e.g., 'http', 'https'). + :type scheme: str + :param host: The host of the URL. + :type host: str + :param port: The port number of the URL, defaults to None. + :type port: int, optional + :param username: The username for authentication, defaults to an empty string. + :type username: str, optional + :param password: The password for authentication, defaults to an empty string. + :type password: str, optional + :param path: The path of the URL, defaults to an empty string. + :type path: str, optional + :param parameters: The query parameters of the URL as a dictionary, defaults to None. + :type parameters: Dict[str, str], optional + :param attributes: Additional attributes of the URL as a dictionary, defaults to None. + :type attributes: Dict[str, Any], optional + """ + self._scheme = scheme + self._host = host + self._port = port + self._location = f"{host}:{port}" if port else host + self._username = username + self._password = password + self._path = path + self._parameters = parameters or {} + self._attributes = attributes or {} + + @property + def scheme(self) -> str: + """ + Get or set the scheme of the URL. + + :return: The scheme of the URL. + :rtype: str + """ + return self._scheme + + @scheme.setter + def scheme(self, value: str): + self._scheme = value + + @property + def host(self) -> str: + """ + Get or set the host of the URL. + + :return: The host of the URL. + :rtype: str + """ + return self._host + + @host.setter + def host(self, value: str): + self._host = value + self._location = f"{self.host}:{self.port}" if self.port else self.host + + @property + def port(self) -> Optional[int]: + """ + Get or set the port of the URL. + + :return: The port of the URL. + :rtype: int, optional + """ + return self._port + + @port.setter + def port(self, value: int): + if value > 0: + self._port = value + self._location = f"{self.host}:{self.port}" + + @property + def location(self) -> str: + """ + Get or set the location (host:port) of the URL. + + :return: The location of the URL. + :rtype: str + """ + return self._location + + @location.setter + def location(self, value: str): + try: + values = value.split(":") + self.host = values[0] + if len(values) == 2: + self.port = int(values[1]) + except Exception as e: + raise ValueError(f"Invalid location: {value}") from e + + @property + def username(self) -> str: + """ + Get or set the username for authentication. + + :return: The username. 
+ :rtype: str + """ + return self._username + + @username.setter + def username(self, value: str): + self._username = value + + @property + def password(self) -> str: + """ + Get or set the password for authentication. + + :return: The password. + :rtype: str + """ + return self._password + + @password.setter + def password(self, value: str): + self._password = value + + @property + def path(self) -> str: + """ + Get or set the path of the URL. + + :return: The path of the URL. + :rtype: str + """ + return self._path + + @path.setter + def path(self, value: str): + self._path = value.lstrip("/") + + @property + def parameters(self) -> Dict[str, str]: + """ + Get the query parameters of the URL. + + :return: The query parameters as a dictionary. + :rtype: Dict[str, str] + """ + return self._parameters + + @property + def attributes(self) -> Dict[str, Any]: + """ + Get the additional attributes of the URL. + + :return: The attributes as a dictionary. + :rtype: Dict[str, Any] + """ + return self._attributes + + def to_str(self, encode: bool = False) -> str: + """ + Converts the URL to a string. + + :param encode: Determines if the URL should be encoded. Defaults to False. + :type encode: bool + :return: The URL string. + :rtype: str + """ + # Construct the netloc part + if self.username and self.password: + netloc = f"{self.username}:{self.password}@{self.host}" + else: + netloc = self.host + + if self.port: + netloc = f"{netloc}:{self.port}" + + # Convert parameters dictionary to query string + query = urlencode(self.parameters) + + # Construct the URL + url = urlunparse((self.scheme or "", netloc, self.path or "/", "", query, "")) + + if encode: + url = parse.quote(url) + + return url + + def copy(self) -> "URL": + """ + Copy the URL object. + + :return: A shallow copy of the URL object. + :rtype: URL + """ + return copy.copy(self) + + def deepcopy(self) -> "URL": + """ + Deep copy the URL object. + + :return: A deep copy of the URL object. + :rtype: URL + """ + return copy.deepcopy(self) + + def __copy__(self) -> "URL": + return URL( + self.scheme, + self.host, + self.port, + self.username, + self.password, + self.path, + self.parameters.copy(), + self.attributes.copy(), + ) + + def __deepcopy__(self, memo) -> "URL": + return URL( + self.scheme, + self.host, + self.port, + self.username, + self.password, + self.path, + copy.deepcopy(self.parameters, memo), + copy.deepcopy(self.attributes, memo), + ) + + def __str__(self) -> str: + return self.to_str() + + def __repr__(self) -> str: + return self.to_str() diff --git a/dubbo/common/utils.py b/dubbo/common/utils.py new file mode 100644 index 0000000..4b20998 --- /dev/null +++ b/dubbo/common/utils.py @@ -0,0 +1,129 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
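As a quick orientation for the helpers defined just below, here is an illustrative sketch of how EventHelper and FutureHelper are meant to be used; the threading.Event and concurrent.futures.Future objects are arbitrary examples, not values taken from this patch.

import threading
from concurrent.futures import Future

from dubbo.common.utils import EventHelper, FutureHelper

event = threading.Event()
EventHelper.set(event)                      # returns True and sets the event
EventHelper.is_set(None)                    # returns False instead of raising

future = Future()
FutureHelper.set_result(future, "ok")       # completes the future
FutureHelper.set_result(future, "ignored")  # no-op: the future is already done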
+ +__all__ = ["EventHelper", "FutureHelper"] + + +class EventHelper: + """ + Helper class for event operations. + """ + + @staticmethod + def is_set(event) -> bool: + """ + Check if the event is set. + + :param event: Event object, you can use threading.Event or any other object that supports the is_set operation. + :type event: Any + :return: True if the event is set, or False if the is_set method is not supported or the event is invalid. + :rtype: bool + """ + return event.is_set() if event and hasattr(event, "is_set") else False + + @staticmethod + def set(event) -> bool: + """ + Attempt to set the event object. + + :param event: Event object, you can use threading.Event or any other object that supports the set operation. + :type event: Any + :return: True if the event was set, False otherwise + (such as the event is invalid or does not support the set operation). + :rtype: bool + """ + if event is None: + return False + + # If the event supports the set operation, set the event and return True + if hasattr(event, "set"): + event.set() + return True + + # If the event is invalid or does not support the set operation, return False + return False + + @staticmethod + def clear(event) -> bool: + """ + Attempt to clear the event object. + + :param event: Event object, you can use threading.Event or any other object that supports the clear operation. + :type event: Any + :return: True if the event was cleared, False otherwise + (such as the event is invalid or does not support the clear operation). + :rtype: bool + """ + if not event: + return False + + # If the event supports the clear operation, clear the event and return True + if hasattr(event, "clear"): + event.clear() + return True + + # If the event is invalid or does not support the clear operation, return False + return False + + +class FutureHelper: + """ + Helper class for future operations. + """ + + @staticmethod + def done(future) -> bool: + """ + Check if the future is done. + + :param future: Future object + :type future: Any + :return: True if the future is done, False otherwise. + :rtype: bool + """ + return future.done() if future and hasattr(future, "done") else False + + @staticmethod + def set_result(future, result): + """ + Set the result of the future. + + :param future: Future object + :type future: Any + :param result: Result to set + :type result: Any + """ + if not future or FutureHelper.done(future): + return + + if hasattr(future, "set_result"): + future.set_result(result) + + @staticmethod + def set_exception(future, exception): + """ + Set the exception to the future. + + :param future: Future object + :type future: Any + :param exception: Exception to set + :type exception: Exception + """ + if not future or FutureHelper.done(future): + return + + if hasattr(future, "set_exception"): + future.set_exception(exception) diff --git a/dubbo/compression/__init__.py b/dubbo/compression/__init__.py new file mode 100644 index 0000000..eb01689 --- /dev/null +++ b/dubbo/compression/__init__.py @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._interfaces import Compressor, Decompressor +from .bzip2s import Bzip2 +from .gzips import Gzip +from .identities import Identity + +__all__ = ["Compressor", "Decompressor", "Identity", "Gzip", "Bzip2"] diff --git a/dubbo/compression/_interfaces.py b/dubbo/compression/_interfaces.py new file mode 100644 index 0000000..d7a8513 --- /dev/null +++ b/dubbo/compression/_interfaces.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +__all__ = ["MessageEncoding", "Compressor", "Decompressor"] + + +class MessageEncoding(abc.ABC): + """ + The message encoding interface. + """ + + @classmethod + @abc.abstractmethod + def get_message_encoding(cls) -> str: + """ + Get message encoding of current compression + :return: The message encoding. + :rtype: str + """ + raise NotImplementedError() + + +class Compressor(MessageEncoding, abc.ABC): + """ + The compression interface. + """ + + @abc.abstractmethod + def compress(self, data: bytes) -> bytes: + """ + Compress the data. + :param data: The data to compress. + :type data: bytes + :return: The compressed data. + :rtype: bytes + """ + raise NotImplementedError() + + +class Decompressor(MessageEncoding, abc.ABC): + """ + The decompressor interface. + """ + + @abc.abstractmethod + def decompress(self, data: bytes) -> bytes: + """ + Decompress the data. + :param data: The data to decompress. + :type data: bytes + :return: The decompressed data. + :rtype: bytes + """ + raise NotImplementedError() diff --git a/dubbo/compression/bzip2s.py b/dubbo/compression/bzip2s.py new file mode 100644 index 0000000..92b2bf0 --- /dev/null +++ b/dubbo/compression/bzip2s.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import bz2 + +from dubbo.compression import Compressor, Decompressor + + +class Bzip2(Compressor, Decompressor): + """ + The BZIP2 compression and decompressor. + """ + + _MESSAGE_ENCODING = "bzip2" + + @classmethod + def get_message_encoding(cls) -> str: + """ + Get message encoding of current compression + :return: The message encoding. + :rtype: str + """ + return cls._MESSAGE_ENCODING + + def compress(self, data: bytes) -> bytes: + """ + Compress the data. + :param data: The data to compress. + :type data: bytes + :return: The compressed data. + :rtype: bytes + """ + return bz2.compress(data) + + def decompress(self, data: bytes) -> bytes: + """ + Decompress the data. + :param data: The data to decompress. + :type data: bytes + :return: The decompressed data. + :rtype: bytes + """ + return bz2.decompress(data) diff --git a/dubbo/compressor/gzip_compression.py b/dubbo/compression/gzips.py similarity index 58% rename from dubbo/compressor/gzip_compression.py rename to dubbo/compression/gzips.py index 803bd55..4b9ac59 100644 --- a/dubbo/compressor/gzip_compression.py +++ b/dubbo/compression/gzips.py @@ -13,32 +13,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gzip -from dubbo.compressor.compression import Compression +from dubbo.compression import Compressor, Decompressor + +__all__ = ["Gzip"] -class GzipCompression(Compression): +class Gzip(Compressor, Decompressor): """ - GZIP Compression implementation + The GZIP compression and decompressor. """ + _MESSAGE_ENCODING = "gzip" + + @classmethod + def get_message_encoding(cls) -> str: + """ + Get message encoding of current compression + :return: The message encoding. + :rtype: str + """ + return cls._MESSAGE_ENCODING + def compress(self, data: bytes) -> bytes: """ - Compress the data using GZIP - Args: - data (bytes): Data to compress - Returns: - bytes: Compressed data + Compress the data. + :param data: The data to compress. + :type data: bytes + :return: The compressed data. + :rtype: bytes """ return gzip.compress(data) def decompress(self, data: bytes) -> bytes: """ - Decompress the data using GZIP - Args: - data (bytes): Data to decompress - Returns: - bytes: Decompressed data + Decompress the data. + :param data: The data to decompress. + :type data: bytes + :return: The decompressed data. + :rtype: bytes """ return gzip.decompress(data) diff --git a/dubbo/compression/identities.py b/dubbo/compression/identities.py new file mode 100644 index 0000000..0d039b3 --- /dev/null +++ b/dubbo/compression/identities.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.common import SingletonBase +from dubbo.compression import Compressor, Decompressor + +__all__ = ["Identity"] + + +class Identity(Compressor, Decompressor, SingletonBase): + """ + The identity compression and decompressor.It does not compress or decompress the data. + """ + + _MESSAGE_ENCODING = "identity" + + @classmethod + def get_message_encoding(cls) -> str: + """ + Get message encoding of current compression + :return: The message encoding. + :rtype: str + """ + return cls._MESSAGE_ENCODING + + def compress(self, data: bytes) -> bytes: + """ + Compress the data. + :param data: The data to compress. + :type data: bytes + :return: The compressed data. + :rtype: bytes + """ + return data + + def decompress(self, data: bytes) -> bytes: + """ + Decompress the data. + :param data: The data to decompress. + :type data: bytes + :return: The decompressed data. + :rtype: bytes + """ + return data diff --git a/dubbo/config/__init__.py b/dubbo/config/__init__.py index 63d9535..63c4ec1 100644 --- a/dubbo/config/__init__.py +++ b/dubbo/config/__init__.py @@ -13,8 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from .application_config import ApplicationConfig -from .consumer_config import ConsumerConfig from .logger_config import FileLoggerConfig, LoggerConfig from .protocol_config import ProtocolConfig from .reference_config import ReferenceConfig diff --git a/dubbo/config/logger_config.py b/dubbo/config/logger_config.py index dfdf8ab..f34ce13 100644 --- a/dubbo/config/logger_config.py +++ b/dubbo/config/logger_config.py @@ -13,15 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
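Before the logger configuration changes continue, a quick hedged sanity sketch of the compression classes introduced above (Identity, Gzip, Bzip2); the payload is an arbitrary placeholder.

# Round-trip each codec and report its message encoding and compressed size.
from dubbo.compression import Bzip2, Gzip, Identity

payload = b"hello dubbo " * 100

for codec in (Identity(), Gzip(), Bzip2()):
    compressed = codec.compress(payload)
    assert codec.decompress(compressed) == payload
    print(codec.get_message_encoding(), len(compressed))

# "identity" leaves the bytes untouched, so its size equals len(payload);
# "gzip" and "bzip2" shrink the repetitive payload considerably.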
+ from dataclasses import dataclass from typing import Dict, Optional -from dubbo.constants import logger_constants as logger_constants -from dubbo.constants.logger_constants import FileRotateType, Level +from dubbo.common.url import URL from dubbo.extension import extensionLoader from dubbo.logger import LoggerAdapter -from dubbo.logger.logger_factory import loggerFactory -from dubbo.url import URL +from dubbo.logger import constants as logger_constants +from dubbo.logger import loggerFactory +from dubbo.logger.constants import Level @dataclass @@ -39,7 +40,7 @@ class FileLoggerConfig: """ - rotate: FileRotateType = FileRotateType.NONE + rotate: logger_constants.FileRotateType = logger_constants.FileRotateType.NONE file_formatter: Optional[str] = None file_dir: str = logger_constants.DEFAULT_FILE_DIR_VALUE file_name: str = logger_constants.DEFAULT_FILE_NAME_VALUE @@ -48,9 +49,9 @@ class FileLoggerConfig: interval: int = logger_constants.DEFAULT_FILE_INTERVAL_VALUE def check(self) -> None: - if self.rotate == FileRotateType.SIZE and self.max_bytes < 0: + if self.rotate == logger_constants.FileRotateType.SIZE and self.max_bytes < 0: raise ValueError("Max bytes can't be less than 0") - elif self.rotate == FileRotateType.TIME and self.interval < 1: + elif self.rotate == logger_constants.FileRotateType.TIME and self.interval < 1: raise ValueError("Interval can't be less than 1") def dict(self) -> Dict[str, str]: diff --git a/dubbo/config/reference_config.py b/dubbo/config/reference_config.py index 1e1530d..a7f258c 100644 --- a/dubbo/config/reference_config.py +++ b/dubbo/config/reference_config.py @@ -13,13 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import threading from typing import Optional, Union +from dubbo.common import URL, create_url from dubbo.extension import extensionLoader -from dubbo.protocol.invoker import Invoker -from dubbo.protocol.protocol import Protocol -from dubbo.url import URL +from dubbo.protocol import Invoker, Protocol class ReferenceConfig: @@ -36,7 +36,7 @@ class ReferenceConfig: def __init__(self, url: Union[str, URL], service_name: str): self._initialized = False self._global_lock = threading.Lock() - self._url: URL = url if isinstance(url, URL) else URL.value_of(url) + self._url: URL = url if isinstance(url, URL) else create_url(url) self._service_name = service_name self._protocol: Optional[Protocol] = None self._invoker: Optional[Invoker] = None diff --git a/dubbo/config/service_config.py b/dubbo/config/service_config.py new file mode 100644 index 0000000..a4f3644 --- /dev/null +++ b/dubbo/config/service_config.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
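For the consumer side touched above, a hedged sketch of creating a service reference from a plain URL string (parsed internally via create_url); the address and service name are placeholders, and the invoker wiring itself happens elsewhere in this series. The provider-side ServiceConfig counterpart follows below.

# ReferenceConfig accepts either a URL object or a URL string.
from dubbo.config import ReferenceConfig

ref = ReferenceConfig(
    "tri://127.0.0.1:50051/org.apache.dubbo.samples.DemoService",
    service_name="org.apache.dubbo.samples.DemoService",
)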
+ +from typing import Optional + +from dubbo.common import URL +from dubbo.common import constants as common_constants +from dubbo.extension import extensionLoader +from dubbo.protocol import Protocol +from dubbo.proxy.handlers import RpcServiceHandler + +__all__ = ["ServiceConfig"] + + +class ServiceConfig: + """ + Service configuration + """ + + def __init__( + self, + service_handler: RpcServiceHandler, + port: Optional[int] = None, + protocol: Optional[str] = None, + ): + + self._service_handler = service_handler + self._port = port or common_constants.DEFAULT_PORT + + protocol_str = protocol or common_constants.TRIPLE_SHORT + + self._export_url = URL( + protocol_str, common_constants.LOCAL_HOST_KEY, self._port + ) + self._export_url.attributes[common_constants.SERVICE_HANDLER_KEY] = ( + service_handler + ) + + self._protocol: Protocol = extensionLoader.get_extension( + Protocol, protocol_str + )(self._export_url) + + self._exported = False + self._exporting = False + + def export(self): + """ + Export service + """ + if self._exporting or self._exported: + return + + self._exporting = True + try: + self._protocol.export(self._export_url) + self._exported = True + finally: + self._exporting = False diff --git a/dubbo/extension/__init__.py b/dubbo/extension/__init__.py index 0da2118..50859ba 100644 --- a/dubbo/extension/__init__.py +++ b/dubbo/extension/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dubbo.extension.extension_loader import ExtensionError from dubbo.extension.extension_loader import ExtensionLoader as _ExtensionLoader extensionLoader = _ExtensionLoader() diff --git a/dubbo/extension/extension_loader.py b/dubbo/extension/extension_loader.py index 3c96040..7ec801d 100644 --- a/dubbo/extension/extension_loader.py +++ b/dubbo/extension/extension_loader.py @@ -13,77 +13,82 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import importlib -import threading from typing import Any -from dubbo.extension import registry -from dubbo.logger.logger_factory import loggerFactory +from dubbo.common import SingletonBase +from dubbo.extension import registries as registries_module -logger = loggerFactory.get_logger(__name__) +class ExtensionError(Exception): + """ + Extension error. + """ -class ExtensionLoader: + def __init__(self, message: str): + """ + Initialize the extension error. + :param message: The error message. + :type message: str + """ + super().__init__(message) - _instance = None - _ins_lock = threading.Lock() - def __new__(cls, *args, **kwargs): - if cls._instance is None: - with cls._ins_lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance +class ExtensionLoader(SingletonBase): + """ + Singleton class for loading extension implementations. + """ def __init__(self): - self._registries = registry.get_all_extended_registry() + """ + Initialize the extension loader. + + Load all the registries from the registries module. 
+ """ + if not hasattr(self, "_initialized"): # Ensure __init__ runs only once + self._registries = {} + for name in registries_module.__all__: + registry = getattr(registries_module, name) + self._registries[registry.interface] = registry.impls + self._initialized = True - def get_extension(self, superclass: Any, name: str) -> Any: - # Get the registry for the extension - extension_impls = self._registries.get(superclass) - err_msg = None - if not extension_impls: - err_msg = f"Extension {superclass} is not registered." - logger.error(err_msg) - raise ValueError(err_msg) + def get_extension(self, interface: Any, impl_name: str) -> Any: + """ + Get the extension implementation for the interface. - # Get the full name of the class -> module.class - full_name = extension_impls.get(name) + :param interface: Interface class. + :type interface: Any + :param impl_name: Implementation name. + :type impl_name: str + :return: Extension implementation class. + :rtype: Any + :raises ExtensionError: If the interface or implementation is not found. + """ + # Get the registry for the interface + impls = self._registries.get(interface) + if not impls: + raise ExtensionError(f"Interface '{interface.__name__}' is not supported.") + + # Get the full name of the implementation + full_name = impls.get(impl_name) if not full_name: - err_msg = f"Extension {superclass} with name {name} is not registered." - logger.error(err_msg) - raise ValueError(err_msg) + raise ExtensionError( + f"Implementation '{impl_name}' for interface '{interface.__name__}' is not exist." + ) - module_name = class_name = None try: # Split the full name into module and class module_name, class_name = full_name.rsplit(".", 1) - # Load the module + + # Load the module and get the class module = importlib.import_module(module_name) - # Get the class from the module subclass = getattr(module, class_name) - if subclass: - # Check if the class is a subclass of the extension - if issubclass(subclass, superclass) and subclass is not superclass: - # Return the class - return subclass - else: - err_msg = f"Class {class_name} does not inherit from {superclass}." - else: - err_msg = f"Class {class_name} not found in module {module_name}" - if err_msg: - # If there is an error message, raise an exception - raise Exception(err_msg) - except ImportError as e: - logger.exception(f"Module {module_name} could not be imported.") - raise e - except AttributeError as e: - logger.exception(f"Class {class_name} not found in {module_name}.") - raise e + # Return the subclass + return subclass except Exception as e: - if err_msg: - logger.error(err_msg) - else: - logger.exception(f"An error occurred while loading {full_name}.") - raise e + raise ExtensionError( + f"Failed to load extension '{impl_name}' for interface '{interface.__name__}'. \n" + f"Detail: {e}" + ) diff --git a/dubbo/extension/registry.py b/dubbo/extension/registries.py similarity index 54% rename from dubbo/extension/registry.py rename to dubbo/extension/registries.py index dac28ed..32a5c24 100644 --- a/dubbo/extension/registry.py +++ b/dubbo/extension/registries.py @@ -13,48 +13,72 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect -import sys + from dataclasses import dataclass -from typing import Any +from typing import Any, Dict -from dubbo.compressor.compression import Compression +from dubbo.compression import Compressor, Decompressor from dubbo.logger import LoggerAdapter -from dubbo.protocol.protocol import Protocol -from dubbo.remoting.transporter import Transporter +from dubbo.protocol import Protocol +from dubbo.remoting import Transporter @dataclass class ExtendedRegistry: """ A dataclass to represent an extended registry. - Attributes: - interface: Any -> The interface of the registry. - impls: dict[str, Any] -> A dict of implementations of the interface. -> {name: impl} + + :param interface: The interface of the registry. + :type interface: Any + :param impls: The implementations of the registry. + :type impls: Dict[str, Any] """ interface: Any - impls: dict[str, Any] + impls: Dict[str, Any] + + +# All Extension Registries +__all__ = [ + "protocolRegistry", + "compressorRegistry", + "decompressorRegistry", + "transporterRegistry", + "loggerAdapterRegistry", +] -"""Protocol registry.""" +# Protocol registry protocolRegistry = ExtendedRegistry( interface=Protocol, impls={ - "tri": "dubbo.protocol.triple.tri_protocol.TripleProtocol", + "tri": "dubbo.protocol.triple.protocol.TripleProtocol", }, ) -"""Compression registry.""" -compressionRegistry = ExtendedRegistry( - interface=Compression, +# Compressor registry +compressorRegistry = ExtendedRegistry( + interface=Compressor, impls={ - "gzip": "dubbo.compressor.gzip_compression.GzipCompression", + "identity": "dubbo.compression.Identity", + "gzip": "dubbo.compression.Gzip", + "bzip2": "dubbo.compression.Bzip2", }, ) -"""Transporter registry.""" +# Decompressor registry +decompressorRegistry = ExtendedRegistry( + interface=Decompressor, + impls={ + "identity": "dubbo.compression.Identity", + "gzip": "dubbo.compression.Gzip", + "bzip2": "dubbo.compression.Bzip2", + }, +) + + +# Transporter registry transporterRegistry = ExtendedRegistry( interface=Transporter, impls={ @@ -63,23 +87,10 @@ class ExtendedRegistry: ) -"""LoggerAdapter registry.""" +# Logger Adapter registry loggerAdapterRegistry = ExtendedRegistry( interface=LoggerAdapter, impls={ "logging": "dubbo.logger.logging.logger_adapter.LoggingLoggerAdapter", }, ) - - -def get_all_extended_registry() -> dict[Any, dict[str, Any]]: - """ - Get all extended registries in the current module. - :return: A dict of all extended registries. -> {interface: {name: impl}} - """ - current_module = sys.modules[__name__] - registries: dict[Any, dict[str, Any]] = {} - for name, obj in inspect.getmembers(current_module): - if isinstance(obj, ExtendedRegistry): - registries[obj.interface] = obj.impls - return registries diff --git a/dubbo/constants/__init__.py b/dubbo/loadbalance/__init__.py similarity index 92% rename from dubbo/constants/__init__.py rename to dubbo/loadbalance/__init__.py index bcba37a..ba98b36 100644 --- a/dubbo/constants/__init__.py +++ b/dubbo/loadbalance/__init__.py @@ -13,3 +13,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
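The registries above are plain data: each maps a short extension name to the dotted path of an implementation class, which the loader imports lazily. A hedged illustration follows; the zstd entry and my_pkg target are purely hypothetical.

# Inspect the shipped registries and show how an extra mapping would be added.
from dubbo.extension.registries import compressorRegistry, protocolRegistry

print(protocolRegistry.interface.__name__)  # Protocol
print(protocolRegistry.impls["tri"])        # dubbo.protocol.triple.protocol.TripleProtocol
print(sorted(compressorRegistry.impls))     # ['bzip2', 'gzip', 'identity']

# Hypothetical: once such a module actually exists, the name becomes resolvable
# through extensionLoader.get_extension(Compressor, "zstd").
compressorRegistry.impls["zstd"] = "my_pkg.compression.ZstdCompressor"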
+ +from ._interfaces import AbstractLoadBalance, LoadBalance diff --git a/dubbo/loadbalance/_interfaces.py b/dubbo/loadbalance/_interfaces.py new file mode 100644 index 0000000..dfbf85d --- /dev/null +++ b/dubbo/loadbalance/_interfaces.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import List, Optional + +from dubbo.common import URL +from dubbo.protocol import Invocation, Invoker + + +class LoadBalance(abc.ABC): + """ + The load balance interface. + """ + + @abc.abstractmethod + def select( + self, invokers: List[Invoker], url: URL, invocation: Invocation + ) -> Optional[Invoker]: + """ + Select an invoker from the list. + :param invokers: The invokers. + :type invokers: List[Invoker] + :param url: The URL. + :type url: URL + :param invocation: The invocation. + :type invocation: Invocation + :return: The selected invoker. If no invoker is selected, return None. + :rtype: Optional[Invoker] + """ + raise NotImplementedError() + + +class AbstractLoadBalance(LoadBalance, abc.ABC): + """ + The abstract load balance. + """ + + def select( + self, invokers: List[Invoker], url: URL, invocation: Invocation + ) -> Optional[Invoker]: + if not invokers: + return None + + if len(invokers) == 1: + return invokers[0] + + return self.do_select(invokers, url, invocation) + + @abc.abstractmethod + def do_select( + self, invokers: List[Invoker], url: URL, invocation: Invocation + ) -> Optional[Invoker]: + """ + Do select an invoker from the list. + :param invokers: The invokers. + :type invokers: List[Invoker] + :param url: The URL. + :type url: URL + :param invocation: The invocation. + :type invocation: Invocation + :return: The selected invoker. If no invoker is selected, return None. + :rtype: Optional[Invoker] + """ + raise NotImplementedError() diff --git a/dubbo/logger/__init__.py b/dubbo/logger/__init__.py index c7bee10..4f42594 100644 --- a/dubbo/logger/__init__.py +++ b/dubbo/logger/__init__.py @@ -14,4 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .logger import Logger, LoggerAdapter +from ._interfaces import Logger, LoggerAdapter +from .logger_factory import LoggerFactory as _LoggerFactory + +# The logger factory instance. +loggerFactory = _LoggerFactory() + +__all__ = ["Logger", "LoggerAdapter", "loggerFactory"] diff --git a/dubbo/logger/_interfaces.py b/dubbo/logger/_interfaces.py new file mode 100644 index 0000000..88fa999 --- /dev/null +++ b/dubbo/logger/_interfaces.py @@ -0,0 +1,204 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any + +from dubbo.common.url import URL + +from .constants import Level + +_all__ = ["Logger", "LoggerAdapter"] + + +class Logger(abc.ABC): + """ + Logger Interface, which is used to log messages. + """ + + @abc.abstractmethod + def log(self, level: Level, msg: str, *args: Any, **kwargs: Any) -> None: + """ + Log a message at the specified logging level. + + :param level: The logging level. + :type level: Level + :param msg: The log message. + :type msg: str + :param args: Additional positional arguments. + :type args: Any + :param kwargs: Additional keyword arguments. + :type kwargs: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def debug(self, msg: str, *args, **kwargs) -> None: + """ + Log a debug message. + + :param msg: The debug message. + :type msg: str + :param args: Additional positional arguments. + :type args: Any + :param kwargs: Additional keyword arguments. + :type kwargs: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def info(self, msg: str, *args, **kwargs) -> None: + """ + Log an info message. + + :param msg: The info message. + :type msg: str + :param args: Additional positional arguments. + :type args: Any + :param kwargs: Additional keyword arguments. + :type kwargs: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def warning(self, msg: str, *args, **kwargs) -> None: + """ + Log a warning message. + + :param msg: The warning message. + :type msg: str + :param args: Additional positional arguments. + :type args: Any + :param kwargs: Additional keyword arguments. + :type kwargs: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def error(self, msg: str, *args, **kwargs) -> None: + """ + Log an error message. + + :param msg: The error message. + :type msg: str + :param args: Additional positional arguments. + :type args: Any + :param kwargs: Additional keyword arguments. + :type kwargs: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def critical(self, msg: str, *args, **kwargs) -> None: + """ + Log a critical message. + + :param msg: The critical message. + :type msg: str + :param args: Additional positional arguments. + :type args: Any + :param kwargs: Additional keyword arguments. + :type kwargs: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def fatal(self, msg: str, *args, **kwargs) -> None: + """ + Log a fatal message. + + :param msg: The fatal message. + :type msg: str + :param args: Additional positional arguments. + :type args: Any + :param kwargs: Additional keyword arguments. + :type kwargs: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def exception(self, msg: str, *args, **kwargs) -> None: + """ + Log an exception message. + + :param msg: The exception message. + :type msg: str + :param args: Additional positional arguments. + :type args: Any + :param kwargs: Additional keyword arguments. 
+ :type kwargs: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def is_enabled_for(self, level: Level) -> bool: + """ + Check if this logger is enabled for the specified level. + + :param level: The logging level. + :type level: Level + :return: Whether the logging level is enabled. + :rtype: bool + """ + raise NotImplementedError() + + +class LoggerAdapter(abc.ABC): + """ + Logger Adapter Interface, which is used to support different logging libraries. + """ + + __slots__ = ["_config"] + + def __init__(self, config: URL): + """ + Initialize the logger adapter. + + :param config: The configuration of the logger adapter. + :type config: URL + """ + self._config = config + + def get_logger(self, name: str) -> Logger: + """ + Get a logger by name. + + :param name: The name of the logger. + :type name: str + :return: An instance of the logger. + :rtype: Logger + """ + raise NotImplementedError() + + @property + def level(self) -> Level: + """ + Get the current logging level. + + :return: The current logging level. + :rtype: Level + """ + raise NotImplementedError() + + @level.setter + def level(self, level: Level) -> None: + """ + Set the logging level. + + :param level: The logging level to set. + :type level: Level + """ + raise NotImplementedError() diff --git a/dubbo/constants/logger_constants.py b/dubbo/logger/constants.py similarity index 64% rename from dubbo/constants/logger_constants.py rename to dubbo/logger/constants.py index 40ae17e..a6cae5d 100644 --- a/dubbo/constants/logger_constants.py +++ b/dubbo/logger/constants.py @@ -13,15 +13,47 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import enum import os -from functools import cache + +__all__ = [ + "Level", + "FileRotateType", + "LEVEL_KEY", + "DRIVER_KEY", + "CONSOLE_ENABLED_KEY", + "FILE_ENABLED_KEY", + "FILE_DIR_KEY", + "FILE_NAME_KEY", + "FILE_ROTATE_KEY", + "FILE_MAX_BYTES_KEY", + "FILE_INTERVAL_KEY", + "FILE_BACKUP_COUNT_KEY", + "DEFAULT_DRIVER_VALUE", + "DEFAULT_LEVEL_VALUE", + "DEFAULT_CONSOLE_ENABLED_VALUE", + "DEFAULT_FILE_ENABLED_VALUE", + "DEFAULT_FILE_DIR_VALUE", + "DEFAULT_FILE_NAME_VALUE", + "DEFAULT_FILE_MAX_BYTES_VALUE", + "DEFAULT_FILE_INTERVAL_VALUE", + "DEFAULT_FILE_BACKUP_COUNT_VALUE", +] @enum.unique class Level(enum.Enum): """ The logging level enum. + + :cvar DEBUG: Debug level. + :cvar INFO: Info level. + :cvar WARNING: Warning level. + :cvar ERROR: Error level. + :cvar CRITICAL: Critical level. + :cvar FATAL: Fatal level. + :cvar UNKNOWN: Unknown level. """ DEBUG = "DEBUG" @@ -30,28 +62,37 @@ class Level(enum.Enum): ERROR = "ERROR" CRITICAL = "CRITICAL" FATAL = "FATAL" + UNKNOWN = "UNKNOWN" @classmethod - @cache def get_level(cls, level_value: str) -> "Level": + """ + Get the level from the level value. + + :param level_value: The level value. + :type level_value: str + :return: The level. If the level value is invalid, return UNKNOWN. + :rtype: Level + """ level_value = level_value.upper() for level in cls: if level_value == level.value: return level - raise ValueError("Log level invalid") + return cls.UNKNOWN @enum.unique class FileRotateType(enum.Enum): """ The file rotating type enum. + + :cvar NONE: No rotating. + :cvar SIZE: Rotate the file by size. + :cvar TIME: Rotate the file by time. """ - # No rotating. NONE = "NONE" - # Rotate the file by size. SIZE = "SIZE" - # Rotate the file by time. 
TIME = "TIME" diff --git a/dubbo/logger/logger.py b/dubbo/logger/logger.py deleted file mode 100644 index 00607a8..0000000 --- a/dubbo/logger/logger.py +++ /dev/null @@ -1,175 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any - -from dubbo.constants.logger_constants import Level -from dubbo.url import URL - - -class Logger: - """ - Logger Interface, which is used to log messages. - """ - - def log(self, level: Level, msg: str, *args: Any, **kwargs: Any) -> None: - """ - Log a message at the specified logging level. - - Args: - level (Level): The logging level. - msg (str): The log message. - *args (Any): Additional positional arguments. - **kwargs (Any): Additional keyword arguments. - """ - raise NotImplementedError("log() is not implemented.") - - def debug(self, msg: str, *args, **kwargs) -> None: - """ - Log a debug message. - - Args: - msg (str): The debug message. - *args (Any): Additional positional arguments. - **kwargs (Any): Additional keyword arguments. - """ - raise NotImplementedError("debug() is not implemented.") - - def info(self, msg: str, *args, **kwargs) -> None: - """ - Log an info message. - - Args: - msg (str): The info message. - *args (Any): Additional positional arguments. - **kwargs (Any): Additional keyword arguments. - """ - raise NotImplementedError("info() is not implemented.") - - def warning(self, msg: str, *args, **kwargs) -> None: - """ - Log a warning message. - - Args: - msg (str): The warning message. - *args (Any): Additional positional arguments. - **kwargs (Any): Additional keyword arguments. - """ - raise NotImplementedError("warning() is not implemented.") - - def error(self, msg: str, *args, **kwargs) -> None: - """ - Log an error message. - - Args: - msg (str): The error message. - *args (Any): Additional positional arguments. - **kwargs (Any): Additional keyword arguments. - """ - raise NotImplementedError("error() is not implemented.") - - def critical(self, msg: str, *args, **kwargs) -> None: - """ - Log a critical message. - - Args: - msg (str): The critical message. - *args (Any): Additional positional arguments. - **kwargs (Any): Additional keyword arguments. - """ - raise NotImplementedError("critical() is not implemented.") - - def fatal(self, msg: str, *args, **kwargs) -> None: - """ - Log a fatal message. - - Args: - msg (str): The fatal message. - *args (Any): Additional positional arguments. - **kwargs (Any): Additional keyword arguments. - """ - raise NotImplementedError("fatal() is not implemented.") - - def exception(self, msg: str, *args, **kwargs) -> None: - """ - Log an exception message. - - Args: - msg (str): The exception message. - *args (Any): Additional positional arguments. - **kwargs (Any): Additional keyword arguments. 
- """ - raise NotImplementedError("exception() is not implemented.") - - def is_enabled_for(self, level: Level) -> bool: - """ - Is this logger enabled for level 'level'? - Args: - level (Level): The logging level. - Return: - bool: Whether the logging level is enabled. - """ - raise ValueError("is_enabled_for() is not implemented.") - - -class LoggerAdapter: - """ - Logger Adapter Interface, which is used to support different logging libraries. - Attributes: - _config(URL): logger adapter configuration. - """ - - _config: URL - - def __init__(self, config: URL): - """ - Initialize the logger adapter. - - Args: - config(URL): config (URL): The config of the logger adapter. - """ - self._config = config - - def get_logger(self, name: str) -> Logger: - """ - Get a logger by name. - - Args: - name (str): The name of the logger. - - Returns: - Logger: An instance of the logger. - """ - raise NotImplementedError("get_logger() is not implemented.") - - @property - def level(self) -> Level: - """ - Get the current logging level. - - Returns: - Level: The current logging level. - """ - raise NotImplementedError("get_level() is not implemented.") - - @level.setter - def level(self, level: Level) -> None: - """ - Set the logging level. - - Args: - level (Level): The logging level to set. - """ - raise NotImplementedError("set_level() is not implemented.") diff --git a/dubbo/logger/logger_factory.py b/dubbo/logger/logger_factory.py index 59a291b..0a7d0b2 100644 --- a/dubbo/logger/logger_factory.py +++ b/dubbo/logger/logger_factory.py @@ -13,17 +13,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import threading -from typing import Dict +from typing import Dict, Optional + +from dubbo.common import SingletonBase +from dubbo.common.url import URL +from dubbo.logger import Logger, LoggerAdapter +from dubbo.logger import constants as logger_constants +from dubbo.logger.constants import Level -from dubbo.constants import logger_constants as logger_constants -from dubbo.constants.logger_constants import Level -from dubbo.logger.logger import Logger, LoggerAdapter -from dubbo.logger.logging.logger_adapter import LoggingLoggerAdapter -from dubbo.url import URL +__all__ = ["LoggerFactory"] # Default logger config with default values. -_default_config = URL( +_DEFAULT_CONFIG = URL( scheme=logger_constants.DEFAULT_DRIVER_VALUE, host=logger_constants.DEFAULT_LEVEL_VALUE.value, parameters={ @@ -39,85 +42,86 @@ ) -class _LoggerFactory: - """ - LoggerFactory - Attributes: - _logger_adapter (LoggerAdapter): The logger adapter. - _loggers (Dict[str, Logger]): The logger cache. - _loggers_lock (threading.Lock): The logger lock to protect the logger cache. +class LoggerFactory(SingletonBase): """ + Singleton factory class for creating and managing loggers. - _logger_adapter = LoggingLoggerAdapter(_default_config) - _loggers: Dict[str, Logger] = {} - _loggers_lock = threading.Lock() + This class ensures a single instance of the logger factory, provides methods to set and get + logger adapters, and manages logger instances. 
+ """ - @classmethod - def set_logger_adapter(cls, logger_adapter) -> None: - """ - Set logger config + def __init__(self): """ - cls._logger_adapter = logger_adapter - cls._loggers_lock.acquire() - try: - # update all loggers - cls._loggers = { - name: cls._logger_adapter.get_logger(name) for name in cls._loggers - } - finally: - cls._loggers_lock.release() + Initialize the logger factory. - @classmethod - def get_logger_adapter(cls) -> LoggerAdapter: + This method sets up the internal lock, logger adapter, and logger cache. """ - Get the logger adapter. + self._lock = threading.RLock() + self._logger_adapter: Optional[LoggerAdapter] = None + self._loggers: Dict[str, Logger] = {} - Returns: - LoggerAdapter: The current logger adapter. + def _ensure_logger_adapter(self) -> None: """ - return cls._logger_adapter + Ensure the logger adapter is set. - @classmethod - def get_logger(cls, name: str) -> Logger: + If the logger adapter is not set, this method sets it to the default adapter. """ - Get the logger by name. + if not self._logger_adapter: + with self._lock: + if not self._logger_adapter: + # Import here to avoid circular imports + from dubbo.logger.logging.logger_adapter import LoggingLoggerAdapter - Args: - name (str): The name of the logger to retrieve. + self.set_logger_adapter(LoggingLoggerAdapter(_DEFAULT_CONFIG)) - Returns: - Logger: An instance of the requested logger. + def set_logger_adapter(self, logger_adapter: LoggerAdapter) -> None: """ - logger = cls._loggers.get(name) - if not logger: - cls._loggers_lock.acquire() - try: - if name not in cls._loggers: - cls._loggers[name] = cls._logger_adapter.get_logger(name) - logger = cls._loggers[name] - finally: - cls._loggers_lock.release() - return logger + Set the logger adapter. - @classmethod - def get_level(cls) -> Level: + :param logger_adapter: The new logger adapter to use. + :type logger_adapter: LoggerAdapter """ - Get the current logging level. + with self._lock: + self._logger_adapter = logger_adapter + # Update all loggers + self._loggers = { + name: self._logger_adapter.get_logger(name) for name in self._loggers + } - Returns: - Level: The current logging level. + def get_logger_adapter(self) -> LoggerAdapter: """ - return cls._logger_adapter.level + Get the current logger adapter. - @classmethod - def set_level(cls, level: Level) -> None: + :return: The current logger adapter. + :rtype: LoggerAdapter """ - Set the logging level. + self._ensure_logger_adapter() + return self._logger_adapter - Args: - level (Level): The logging level to set. + def get_logger(self, name: str) -> Logger: """ - cls._logger_adapter.level = level + Get the logger by name. + :param name: The name of the logger to retrieve. + :type name: str + :return: An instance of the requested logger. + :rtype: Logger + """ + self._ensure_logger_adapter() + logger = self._loggers.get(name) + if not logger: + with self._lock: + if name not in self._loggers: + self._loggers[name] = self._logger_adapter.get_logger(name) + logger = self._loggers[name] + return logger -loggerFactory = _LoggerFactory + def get_level(self) -> Level: + """ + Get the current logging level. + + :return: The current logging level. + :rtype: Level + """ + self._ensure_logger_adapter() + return self._logger_adapter.level diff --git a/dubbo/logger/logging/__init__.py b/dubbo/logger/logging/__init__.py index d8765ff..10e45eb 100644 --- a/dubbo/logger/logging/__init__.py +++ b/dubbo/logger/logging/__init__.py @@ -15,3 +15,5 @@ # limitations under the License. 
from .logger_adapter import LoggerAdapter + +__all__ = ["LoggerAdapter"] diff --git a/dubbo/logger/logging/formatter.py b/dubbo/logger/logging/formatter.py index 56a002a..1dc409e 100644 --- a/dubbo/logger/logging/formatter.py +++ b/dubbo/logger/logging/formatter.py @@ -13,10 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging import re from enum import Enum +__all__ = ["ColorFormatter", "NoColorFormatter", "Colors"] + class Colors(Enum): """ diff --git a/dubbo/logger/logging/logger.py b/dubbo/logger/logging/logger.py index 8fcb929..d8feb77 100644 --- a/dubbo/logger/logging/logger.py +++ b/dubbo/logger/logging/logger.py @@ -17,11 +17,14 @@ import logging from typing import Dict -from dubbo.constants.logger_constants import Level from dubbo.logger import Logger +from ..constants import Level + +__all__ = ["LoggingLogger"] + # The mapping from the logging level to the logging level. -_level_map: Dict[Level, int] = { +LEVEL_MAP: Dict[Level, int] = { Level.DEBUG: logging.DEBUG, Level.INFO: logging.INFO, Level.WARNING: logging.WARNING, @@ -30,26 +33,38 @@ Level.FATAL: logging.FATAL, } +STACKLEVEL_KEY = "stacklevel" +STACKLEVEL_DEFAULT = 1 +STACKLEVEL_OFFSET = 2 + +EXC_INFO_KEY = "exc_info" +EXC_INFO_DEFAULT = True + class LoggingLogger(Logger): """ The logging logger implementation. - Attributes: - _logger (logging.Logger): The real working logger object """ - _logger: logging.Logger + __slots__ = ["_logger"] def __init__(self, internal_logger: logging.Logger): + """ + Initialize the logger. + :param internal_logger: The internal logger. + :type internal_logger: logging + """ self._logger = internal_logger def _log(self, level: int, msg: str, *args, **kwargs) -> None: # Add the stacklevel to the keyword arguments. 
- kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 2 + kwargs[STACKLEVEL_KEY] = ( + kwargs.get(STACKLEVEL_KEY, STACKLEVEL_DEFAULT) + STACKLEVEL_OFFSET + ) self._logger.log(level, msg, *args, **kwargs) def log(self, level: Level, msg: str, *args, **kwargs) -> None: - self._log(_level_map[level], msg, *args, **kwargs) + self._log(LEVEL_MAP[level], msg, *args, **kwargs) def debug(self, msg: str, *args, **kwargs) -> None: self._log(logging.DEBUG, msg, *args, **kwargs) @@ -70,10 +85,10 @@ def fatal(self, msg: str, *args, **kwargs) -> None: self._log(logging.FATAL, msg, *args, **kwargs) def exception(self, msg: str, *args, **kwargs) -> None: - if kwargs.get("exc_info") is None: - kwargs["exc_info"] = True + if kwargs.get(EXC_INFO_KEY) is None: + kwargs[EXC_INFO_KEY] = EXC_INFO_DEFAULT self.error(msg, *args, **kwargs) def is_enabled_for(self, level: Level) -> bool: - logging_level = _level_map.get(level) + logging_level = LEVEL_MAP.get(level) return self._logger.isEnabledFor(logging_level) if logging_level else False diff --git a/dubbo/logger/logging/logger_adapter.py b/dubbo/logger/logging/logger_adapter.py index f4d36b4..3e60813 100644 --- a/dubbo/logger/logging/logger_adapter.py +++ b/dubbo/logger/logging/logger_adapter.py @@ -20,58 +20,67 @@ from functools import cache from logging import handlers -from dubbo.constants import common_constants -from dubbo.constants import logger_constants as logger_constants -from dubbo.constants.logger_constants import FileRotateType, Level +from dubbo.common import constants as common_constants +from dubbo.common.url import URL from dubbo.logger import Logger, LoggerAdapter +from dubbo.logger import constants as logger_constants +from dubbo.logger.constants import LEVEL_KEY, Level from dubbo.logger.logging import formatter from dubbo.logger.logging.logger import LoggingLogger -from dubbo.url import URL """This module provides the logging logger implementation. -> logging module""" +__all__ = ["LoggingLoggerAdapter"] + class LoggingLoggerAdapter(LoggerAdapter): """ - Internal logger adapter.Responsible for logging logger creation, encapsulated the logging.getLogger() method - Attributes: - _level(Level): logging level. + Internal logger adapter responsible for creating loggers and encapsulating the logging.getLogger() method. """ - _level: Level + __slots__ = ["_level"] def __init__(self, config: URL): + """ + Initialize the LoggingLoggerAdapter with the given configuration. + + :param config: The configuration URL for the logger adapter. + :type config: URL + """ super().__init__(config) # Set level - level_name = config.get_parameter(logger_constants.LEVEL_KEY) + level_name = config.parameters.get(LEVEL_KEY) self._level = Level.get_level(level_name) if level_name else Level.DEBUG self._update_level() def get_logger(self, name: str) -> Logger: """ Create a logger instance by name. - Args: - name (str): The logger name. - Returns: - Logger: The InternalLogger instance. + + :param name: The logger name. + :type name: str + :return: An instance of the logger. 
+ :rtype: Logger """ logger_instance = logging.getLogger(name) # clean up handlers logger_instance.handlers.clear() # Add console handler - console_enabled = self._config.get_parameter( - logger_constants.CONSOLE_ENABLED_KEY - ) or str(logger_constants.DEFAULT_CONSOLE_ENABLED_VALUE) + console_enabled = self._config.parameters.get( + logger_constants.CONSOLE_ENABLED_KEY, + str(logger_constants.DEFAULT_CONSOLE_ENABLED_VALUE), + ) if console_enabled.lower() == common_constants.TRUE_VALUE or bool( sys.stdout and sys.stdout.isatty() ): logger_instance.addHandler(self._get_console_handler()) # Add file handler - file_enabled = self._config.get_parameter( - logger_constants.FILE_ENABLED_KEY - ) or str(logger_constants.DEFAULT_FILE_ENABLED_VALUE) + file_enabled = self._config.parameters.get( + logger_constants.FILE_ENABLED_KEY, + str(logger_constants.DEFAULT_FILE_ENABLED_VALUE), + ) if file_enabled.lower() == common_constants.TRUE_VALUE: logger_instance.addHandler(self._get_file_handler()) @@ -84,9 +93,10 @@ def get_logger(self, name: str) -> Logger: @cache def _get_console_handler(self) -> logging.StreamHandler: """ - Get the console handler.(Avoid duplicate consoleHandler creation with @cache) - Returns: - logging.StreamHandler: The console handler. + Get the console handler, avoiding duplicate creation with caching. + + :return: The console handler. + :rtype: logging.StreamHandler """ console_handler = logging.StreamHandler() console_handler.setFormatter(formatter.ColorFormatter()) @@ -96,39 +106,41 @@ def _get_console_handler(self) -> logging.StreamHandler: @cache def _get_file_handler(self) -> logging.Handler: """ - Get the file handler.(Avoid duplicate fileHandler creation with @cache) - Returns: - logging.Handler: The file handler. + Get the file handler, avoiding duplicate creation with caching. + + :return: The file handler. 
+ :rtype: logging.Handler """ # Get file path - file_dir = self._config.get_parameter(logger_constants.FILE_DIR_KEY) - file_name = ( - self._config.get_parameter(logger_constants.FILE_NAME_KEY) - or logger_constants.DEFAULT_FILE_NAME_VALUE + file_dir = self._config.parameters.get(logger_constants.FILE_DIR_KEY) + file_name = self._config.parameters.get( + logger_constants.FILE_NAME_KEY, logger_constants.DEFAULT_FILE_NAME_VALUE ) file_path = os.path.join(file_dir, file_name) # Get backup count backup_count = int( - self._config.get_parameter(logger_constants.FILE_BACKUP_COUNT_KEY) - or logger_constants.DEFAULT_FILE_BACKUP_COUNT_VALUE + self._config.parameters.get( + logger_constants.FILE_BACKUP_COUNT_KEY, + logger_constants.DEFAULT_FILE_BACKUP_COUNT_VALUE, + ) ) # Get rotate type - rotate_type = self._config.get_parameter(logger_constants.FILE_ROTATE_KEY) + rotate_type = self._config.parameters.get(logger_constants.FILE_ROTATE_KEY) # Set file Handler file_handler: logging.Handler - if rotate_type == FileRotateType.SIZE.value: + if rotate_type == logger_constants.FileRotateType.SIZE.value: # Set RotatingFileHandler max_bytes = int( - self._config.get_parameter(logger_constants.FILE_MAX_BYTES_KEY) + self._config.parameters.get(logger_constants.FILE_MAX_BYTES_KEY) ) file_handler = handlers.RotatingFileHandler( file_path, maxBytes=max_bytes, backupCount=backup_count ) - elif rotate_type == FileRotateType.TIME.value: + elif rotate_type == logger_constants.FileRotateType.TIME.value: # Set TimedRotatingFileHandler interval = int( - self._config.get_parameter(logger_constants.FILE_INTERVAL_KEY) + self._config.parameters.get(logger_constants.FILE_INTERVAL_KEY) ) file_handler = handlers.TimedRotatingFileHandler( file_path, interval=interval, backupCount=backup_count @@ -145,8 +157,9 @@ def _get_file_handler(self) -> logging.Handler: def level(self) -> Level: """ Get the logging level. - Returns: - Level: The logging level. + + :return: The current logging level. + :rtype: Level """ return self._level @@ -154,8 +167,9 @@ def level(self) -> Level: def level(self, level: Level) -> None: """ Set the logging level. - Args: - level (Level): The logging level. + + :param level: The logging level to set. + :type level: Level """ if level == self._level or level is None: return @@ -164,10 +178,9 @@ def level(self, level: Level) -> None: def _update_level(self): """ - Update log level. - Complete the log level change by modifying the root logger + Update the log level by modifying the root logger. """ # Get the root logger root_logger = logging.getLogger() # Set the logging level - root_logger.setLevel(self._level.name) + root_logger.setLevel(self._level.value) diff --git a/dubbo/protocol/__init__.py b/dubbo/protocol/__init__.py index bcba37a..965b73f 100644 --- a/dubbo/protocol/__init__.py +++ b/dubbo/protocol/__init__.py @@ -13,3 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from ._interfaces import Invocation, Invoker, Protocol, Result + +__all__ = ["Invocation", "Result", "Invoker", "Protocol"] diff --git a/dubbo/protocol/_interfaces.py b/dubbo/protocol/_interfaces.py new file mode 100644 index 0000000..b3ba210 --- /dev/null +++ b/dubbo/protocol/_interfaces.py @@ -0,0 +1,121 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any + +from dubbo.common.node import Node +from dubbo.common.url import URL + +__all__ = ["Invocation", "Result", "Invoker", "Protocol"] + + +class Invocation(abc.ABC): + + @abc.abstractmethod + def get_service_name(self) -> str: + """ + Get the service name. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_method_name(self) -> str: + """ + Get the method name. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_argument(self) -> Any: + """ + Get the method argument. + """ + raise NotImplementedError() + + +class Result(abc.ABC): + """ + Result of a call + """ + + @abc.abstractmethod + def set_value(self, value: Any) -> None: + """ + Set the value of the result + Args: + value: Value to set + """ + raise NotImplementedError() + + @abc.abstractmethod + def value(self) -> Any: + """ + Get the value of the result + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_exception(self, exception: Exception) -> None: + """ + Set the exception to the result + Args: + exception: Exception to set + """ + raise NotImplementedError() + + @abc.abstractmethod + def exception(self) -> Exception: + """ + Get the exception to the result + """ + raise NotImplementedError() + + +class Invoker(Node, abc.ABC): + """ + Invoker + """ + + @abc.abstractmethod + def invoke(self, invocation: Invocation) -> Result: + """ + Invoke the service. + Returns: + Result: The result of the invocation. + """ + raise NotImplementedError() + + +class Protocol(abc.ABC): + + @abc.abstractmethod + def export(self, url: URL): + """ + Export a remote service. + """ + raise NotImplementedError() + + @abc.abstractmethod + def refer(self, url: URL) -> Invoker: + """ + Refer a remote service. + Args: + url (URL): The URL of the remote service. + Returns: + Invoker: The invoker of the remote service. + """ + raise NotImplementedError() diff --git a/dubbo/protocol/invocation.py b/dubbo/protocol/invocation.py index 59f3b03..a3ac662 100644 --- a/dubbo/protocol/invocation.py +++ b/dubbo/protocol/invocation.py @@ -13,28 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional - - -class Invocation: - - def get_service_name(self) -> str: - """ - Get the service name. - """ - raise NotImplementedError("get_service_name() is not implemented.") - def get_method_name(self) -> str: - """ - Get the method name. - """ - raise NotImplementedError("get_method_name() is not implemented.") +from typing import Any, Dict, Optional - def get_argument(self) -> Any: - """ - Get the method argument. 
- """ - raise NotImplementedError("get_args() is not implemented.") +from ._interfaces import Invocation class RpcInvocation(Invocation): @@ -48,6 +30,14 @@ class RpcInvocation(Invocation): attributes (Optional[Dict[str, Any]]): Only used on the caller side, will not appear on the wire. """ + __slots__ = [ + "_service_name", + "_method_name", + "_argument", + "_attachments", + "_attributes", + ] + def __init__( self, service_name: str, @@ -63,49 +53,18 @@ def __init__( self._attributes = attributes or {} def add_attachment(self, key: str, value: str) -> None: - """ - Add an attachment to the invocation. - Args: - key (str): The key of the attachment. - value (str): The value of the attachment. - """ self._attachments[key] = value def get_attachment(self, key: str) -> Optional[str]: - """ - Get the attachment of the invocation. - Args: - key (str): The key of the attachment. - Returns: - The value of the attachment. If the attachment does not exist, return None. - """ return self._attachments.get(key, None) def add_attribute(self, key: str, value: Any) -> None: - """ - Add an attribute to the invocation. - Args: - key (str): The key of the attribute. - value (Any): The value of the attribute. - """ self._attributes[key] = value def get_attribute(self, key: str) -> Optional[Any]: - """ - Get the attribute of the invocation. - Args: - key (str): The key of the attribute. - Returns: - The value of the attribute. If the attribute does not exist, return None. - """ return self._attributes.get(key, None) def get_service_name(self) -> str: - """ - Get the service name. - Returns: - The service name. - """ return self._service_name def get_method_name(self) -> str: diff --git a/dubbo/protocol/invoker.py b/dubbo/protocol/invoker.py deleted file mode 100644 index 763372f..0000000 --- a/dubbo/protocol/invoker.py +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dubbo.node import Node -from dubbo.protocol.invocation import Invocation -from dubbo.protocol.result import Result - - -class Invoker(Node): - - def get_interface(self): - """ - Get service interface. - """ - raise NotImplementedError("get_interface() is not implemented.") - - def invoke(self, invocation: Invocation) -> Result: - """ - Invoke the service. - Returns: - Result: The result of the invocation. - """ - raise NotImplementedError("invoke() is not implemented.") diff --git a/dubbo/protocol/result.py b/dubbo/protocol/result.py deleted file mode 100644 index c263baf..0000000 --- a/dubbo/protocol/result.py +++ /dev/null @@ -1,67 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any - - -class Result: - """ - Result of a call - """ - - def set_value(self, value: Any) -> None: - """ - Set the value of the result - Args: - value: Value to set - """ - raise NotImplementedError("set_value() is not implemented.") - - def value(self) -> Any: - """ - Get the value of the result - """ - raise NotImplementedError("get_value() is not implemented.") - - def set_exception(self, exception: Exception) -> None: - """ - Set the exception to the result - Args: - exception: Exception to set - """ - raise NotImplementedError("set_exception() is not implemented.") - - def exception(self) -> Exception: - """ - Get the exception to the result - """ - raise NotImplementedError("get_exception() is not implemented.") - - def add_attachment(self, key: str, value: Any) -> None: - """ - Add an attachment to the result - Args: - key: Key of the attachment - value: Value of the attachment - """ - raise NotImplementedError("add_attachment() is not implemented.") - - def get_attachment(self, key: str) -> Any: - """ - Get an attachment from the result - Args: - key: Key of the attachment - """ - raise NotImplementedError("get_attachment() is not implemented.") diff --git a/dubbo/protocol/triple/call/__init__.py b/dubbo/protocol/triple/call/__init__.py new file mode 100644 index 0000000..d274978 --- /dev/null +++ b/dubbo/protocol/triple/call/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._interfaces import ClientCall, ServerCall +from .client_call import TripleClientCall + +__all__ = ["ClientCall", "ServerCall", "TripleClientCall"] diff --git a/dubbo/protocol/triple/call/_interfaces.py b/dubbo/protocol/triple/call/_interfaces.py new file mode 100644 index 0000000..08764c8 --- /dev/null +++ b/dubbo/protocol/triple/call/_interfaces.py @@ -0,0 +1,143 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any, Dict + +from dubbo.protocol.triple.metadata import RequestMetadata +from dubbo.protocol.triple.status import TriRpcStatus + +__all__ = ["ClientCall", "ServerCall"] + + +class ClientCall(abc.ABC): + """ + Interface for client call. + """ + + @abc.abstractmethod + def start(self, metadata: RequestMetadata) -> None: + """ + Start this call. + + :param metadata: call metadata + :type metadata: RequestMetadata + """ + raise NotImplementedError() + + @abc.abstractmethod + def send_message(self, message: Any, last: bool) -> None: + """ + Send message to server. + + :param message: message to send + :type message: Any + :param last: whether this message is the last one + :type last: bool + """ + raise NotImplementedError() + + @abc.abstractmethod + def cancel_by_local(self, e: Exception) -> None: + """ + Cancel this call by local. + + :param e: The exception that caused the call to be canceled + :type e: Exception + """ + raise NotImplementedError() + + class Listener(abc.ABC): + """ + Interface for client call listener. + """ + + @abc.abstractmethod + def on_message(self, message: Any) -> None: + """ + Called when a message is received from server. + + :param message: received message + :type message: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def on_close(self, status: TriRpcStatus, trailers: Dict[str, Any]) -> None: + """ + Called when the call is closed. + + :param status: call status + :type status: TriRpcStatus + :param trailers: trailers + :type trailers: Dict[str, Any] + """ + raise NotImplementedError() + + +class ServerCall(abc.ABC): + """ + Interface for server call. + """ + + @abc.abstractmethod + def send_message(self, message: Any) -> None: + """ + Send message to client. + + :param message: message to send + :type message: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def complete(self, status: TriRpcStatus, attachments: Dict[str, Any]) -> None: + """ + Complete this call. + + :param status: call status + :type status: TriRpcStatus + :param attachments: attachments + :type attachments: Dict[str, Any] + """ + raise NotImplementedError() + + class Listener(abc.ABC): + """ + Interface for server call listener. + """ + + @abc.abstractmethod + def on_message(self, message: Any, last: bool) -> None: + """ + Called when a message is received from client. + + :param message: received message + :type message: Any + :param last: whether this message is the last one + :type last: bool + """ + raise NotImplementedError() + + @abc.abstractmethod + def on_close(self, status: TriRpcStatus) -> None: + """ + Called when the call is closed. + + :param status: call status + :type status: TriRpcStatus + """ + raise NotImplementedError() diff --git a/dubbo/protocol/triple/call/client_call.py b/dubbo/protocol/triple/call/client_call.py new file mode 100644 index 0000000..c9700b0 --- /dev/null +++ b/dubbo/protocol/triple/call/client_call.py @@ -0,0 +1,178 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +from dubbo.compression import Compressor, Identity +from dubbo.logger import loggerFactory +from dubbo.protocol.triple.call import ClientCall +from dubbo.protocol.triple.constants import GRpcCode +from dubbo.protocol.triple.metadata import RequestMetadata +from dubbo.protocol.triple.results import TriResult +from dubbo.protocol.triple.status import TriRpcStatus +from dubbo.protocol.triple.stream import ClientStream +from dubbo.protocol.triple.stream.client_stream import TriClientStream +from dubbo.remoting.aio.http2.stream_handler import StreamClientMultiplexHandler +from dubbo.serialization import Deserializer, SerializationError, Serializer + +__all__ = ["TripleClientCall", "DefaultClientCallListener"] + +_LOGGER = loggerFactory.get_logger(__name__) + + +class TripleClientCall(ClientCall, ClientStream.Listener): + """ + Triple client call. + """ + + def __init__( + self, + stream_factory: StreamClientMultiplexHandler, + listener: ClientCall.Listener, + serializer: Serializer, + deserializer: Deserializer, + ): + self._stream_factory = stream_factory + self._client_stream: Optional[ClientStream] = None + self._listener = listener + self._serializer = serializer + self._deserializer = deserializer + self._compressor: Optional[Compressor] = None + + self._headers_sent = False + self._done = False + self._request_metadata: Optional[RequestMetadata] = None + + def start(self, metadata: RequestMetadata) -> None: + self._request_metadata = metadata + + # get compression from metadata + self._compressor = metadata.compressor + + # create a new stream + client_stream = TriClientStream(self, self._compressor) + h2_stream = self._stream_factory.create(client_stream.transport_listener) + client_stream.bind(h2_stream) + self._client_stream = client_stream + + def send_message(self, message: Any, last: bool) -> None: + if self._done: + _LOGGER.warning("Call is done, cannot send message") + return + + # check if headers are sent + if not self._headers_sent: + # send headers + self._headers_sent = True + self._client_stream.send_headers(self._request_metadata.to_headers()) + + # send message + try: + data = self._serializer.serialize(message) + compress_flag = ( + 0 + if self._compressor.get_message_encoding() + == Identity.get_message_encoding() + else 1 + ) + self._client_stream.send_message(data, compress_flag, last) + except SerializationError as e: + _LOGGER.error("Failed to serialize message: %s", e) + # close the stream + self.cancel_by_local(e) + # close the listener + status = TriRpcStatus( + code=GRpcCode.INTERNAL, + description="Failed to serialize message", + ) + self._listener.on_close(status, {}) + + def cancel_by_local(self, e: Exception) -> None: + if self._done: + return + self._done = True + + if not self._client_stream or not self._headers_sent: + return + + status = 
TriRpcStatus( + code=GRpcCode.CANCELLED, + description=f"Call cancelled by client: {e}", + ) + self._client_stream.cancel_by_local(status) + + def on_message(self, data: bytes) -> None: + """ + Called when a message is received from server. + :param data: The message data + :type data: bytes + """ + if self._done: + _LOGGER.warning(f"Received message after call is done, data: {data}") + return + + try: + # Deserialize the message + message = self._deserializer.deserialize(data) + self._listener.on_message(message) + except SerializationError as e: + _LOGGER.error("Failed to deserialize message: %s", e) + # close the stream + self.cancel_by_local(e) + # close the listener + status = TriRpcStatus( + code=GRpcCode.INTERNAL, + description="Failed to deserialize message", + ) + self._listener.on_close(status, {}) + + def on_complete(self, status: TriRpcStatus, attachments: Dict[str, Any]) -> None: + """ + Called when the call is completed. + :param status: The status + :type status: TriRpcStatus + :param attachments: The attachments + :type attachments: Dict[str, Any] + """ + if not self._done: + self._done = True + self._listener.on_close(status, attachments) + + def on_cancel_by_remote(self, status: TriRpcStatus) -> None: + """ + Called when the call is cancelled by remote. + :param status: The status + :type status: TriRpcStatus + """ + self.on_complete(status, {}) + + +class DefaultClientCallListener(ClientCall.Listener): + """ + The default client call listener. + """ + + def __init__(self, result: TriResult): + self._result = result + + def on_message(self, message: Any) -> None: + self._result.set_value(message) + + def on_close(self, status: TriRpcStatus, attachments: Dict[str, Any]) -> None: + if status.code != GRpcCode.OK: + self._result.set_exception(status.as_exception()) + else: + self._result.complete_value() diff --git a/dubbo/protocol/triple/call/server_call.py b/dubbo/protocol/triple/call/server_call.py new file mode 100644 index 0000000..7b96207 --- /dev/null +++ b/dubbo/protocol/triple/call/server_call.py @@ -0,0 +1,268 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
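The compress flag computed in TripleClientCall.send_message above ends up as the first byte of the 5-byte gRPC message header that coders.py later in this patch packs with struct.pack(">BI", ...). A minimal, self-contained sketch of that length-prefixed framing, assuming nothing beyond the standard library (the helper names are illustrative, not the patch's API):

import struct

HEADER_LENGTH = 5  # 1-byte compressed flag + 4-byte big-endian payload length

def encode_frame(payload: bytes, compress_flag: int = 0) -> bytes:
    # Prepend the gRPC data-frame header to an (optionally compressed) payload.
    return struct.pack(">BI", compress_flag, len(payload)) + payload

def decode_frame(buffer: bytes):
    # Return ((compressed, payload), remaining) or (None, buffer) if the frame is incomplete.
    if len(buffer) < HEADER_LENGTH:
        return None, buffer
    flag, length = struct.unpack(">BI", buffer[:HEADER_LENGTH])
    if len(buffer) < HEADER_LENGTH + length:
        return None, buffer
    end = HEADER_LENGTH + length
    return (bool(flag & 0x01), buffer[HEADER_LENGTH:end]), buffer[end:]

if __name__ == "__main__":
    frame = encode_frame(b"hello")
    (compressed, payload), rest = decode_frame(frame)
    assert (compressed, payload, rest) == (False, b"hello", b"")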
+ +import abc +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Dict, Optional + +from dubbo.common import constants as common_constants +from dubbo.common.deliverers import ( + MessageDeliverer, + MultiMessageDeliverer, + SingleMessageDeliverer, +) +from dubbo.protocol.triple.call import ServerCall +from dubbo.protocol.triple.constants import ( + GRpcCode, + TripleHeaderName, + TripleHeaderValue, +) +from dubbo.protocol.triple.status import TriRpcStatus +from dubbo.protocol.triple.stream import ServerStream +from dubbo.proxy.handlers import RpcMethodHandler +from dubbo.remoting.aio.http2.headers import Http2Headers +from dubbo.remoting.aio.http2.registries import HttpStatus +from dubbo.serialization import ( + CustomDeserializer, + CustomSerializer, + DirectDeserializer, + DirectSerializer, +) + + +class TripleServerCall(ServerCall, ServerStream.Listener): + + def __init__(self, stream: ServerStream, method_handler: RpcMethodHandler): + self._stream = stream + self._method_runner: MethodRunner = MethodRunnerFactory.create( + method_handler, self + ) + + self._executor: Optional[ThreadPoolExecutor] = None + + # get serializer + serializing_function = method_handler.response_serializer + self._serializer = ( + CustomSerializer(serializing_function) + if serializing_function + else DirectSerializer() + ) + + # get deserializer + deserializing_function = method_handler.request_serializer + self._deserializer = ( + CustomDeserializer(deserializing_function) + if deserializing_function + else DirectDeserializer() + ) + + self._headers_sent = False + + def send_message(self, message: Any) -> None: + if not self._headers_sent: + headers = Http2Headers() + headers.status = HttpStatus.OK.value + headers.add( + TripleHeaderName.CONTENT_TYPE.value, + TripleHeaderValue.APPLICATION_GRPC_PROTO.value, + ) + self._stream.send_headers(headers) + + serialized_data = self._serializer.serialize(message) + # TODO support compression + self._stream.send_message(serialized_data, False) + + def complete(self, status: TriRpcStatus, attachments: Dict[str, Any]) -> None: + if not attachments.get(TripleHeaderName.CONTENT_TYPE.value): + attachments[TripleHeaderName.CONTENT_TYPE.value] = ( + TripleHeaderValue.APPLICATION_GRPC_PROTO.value + ) + self._stream.complete(status, attachments) + + def on_headers(self, headers: Dict[str, Any]) -> None: + # start a new thread to run the method + self._executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="dubbo-tri-method-" + ) + self._executor.submit(self._method_runner.run) + + def on_message(self, data: bytes) -> None: + deserialized_data = self._deserializer.deserialize(data) + self._method_runner.receive_arg(deserialized_data) + + def on_complete(self) -> None: + self._method_runner.receive_complete() + + def on_cancel_by_remote(self, status: TriRpcStatus) -> None: + # cancel the method runner. + self._executor.shutdown() + self._executor = None + + +class MethodRunner(abc.ABC): + """ + Interface for method runner. + """ + + @abc.abstractmethod + def receive_arg(self, arg: Any) -> None: + """ + Receive argument. + :param arg: argument + :type arg: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def receive_complete(self) -> None: + """ + Receive complete. + """ + raise NotImplementedError() + + @abc.abstractmethod + def run(self) -> None: + """ + Run the method. + """ + raise NotImplementedError() + + @abc.abstractmethod + def handle_result(self, result: Any) -> None: + """ + Handle the result. 
+ :param result: result + :type result: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def handle_exception(self, e: Exception) -> None: + """ + Handle the exception. + :param e: exception. + :type e: Exception + """ + raise NotImplementedError() + + +class DefaultMethodRunner(MethodRunner): + """ + Abstract method runner. + """ + + def __init__( + self, + func: Callable, + server_call: TripleServerCall, + client_stream: bool, + server_stream: bool, + ): + + self._server_call: TripleServerCall = server_call + self._func = func + + self._deliverer: MessageDeliverer = ( + MultiMessageDeliverer() if client_stream else SingleMessageDeliverer() + ) + self._server_stream = server_stream + + self._completed = False + + def receive_arg(self, arg: Any) -> None: + self._deliverer.add(arg) + + def receive_complete(self) -> None: + self._deliverer.complete() + + def run(self) -> None: + try: + if isinstance(self._deliverer, SingleMessageDeliverer): + result = self._func(self._deliverer.get()) + else: + result = self._func(self._deliverer) + # handle the result + self.handle_result(result) + except Exception as e: + # handle the exception + self.handle_exception(e) + + def handle_result(self, result: Any) -> None: + try: + if not self._server_stream: + # get single result + self._server_call.send_message(result) + else: + # get multi results + for message in result: + self._server_call.send_message(message) + + self._server_call.complete(TriRpcStatus(GRpcCode.OK), {}) + self._completed = True + except Exception as e: + self.handle_exception(e) + + def handle_exception(self, e: Exception) -> None: + if not self._completed: + status = TriRpcStatus( + GRpcCode.INTERNAL, + description=f"Invoke method failed: {str(e)}", + cause=e, + ) + self._server_call.complete(status, {}) + self._completed = True + + +class MethodRunnerFactory: + """ + Factory for method runner. + """ + + @staticmethod + def create(method_handler: RpcMethodHandler, server_call) -> MethodRunner: + """ + Create a method runner. + + :param method_handler: method handler + :type method_handler: RpcMethodHandler + :param server_call: server call + :type server_call: TripleServerCall + :return: method runner + :rtype: MethodRunner + """ + client_stream = ( + True + if method_handler.call_type + in [ + common_constants.CLIENT_STREAM_CALL_VALUE, + common_constants.BI_STREAM_CALL_VALUE, + ] + else False + ) + + server_stream = ( + True + if method_handler.call_type + in [ + common_constants.SERVER_STREAM_CALL_VALUE, + common_constants.BI_STREAM_CALL_VALUE, + ] + else False + ) + + return DefaultMethodRunner( + method_handler.behavior, server_call, client_stream, server_stream + ) diff --git a/dubbo/protocol/triple/client/calls.py b/dubbo/protocol/triple/client/calls.py deleted file mode 100644 index 2e6a184..0000000 --- a/dubbo/protocol/triple/client/calls.py +++ /dev/null @@ -1,156 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
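MethodRunnerFactory.create above reduces the handler's call type to two booleans: client-streaming and bidirectional calls feed arguments through a multi-message deliverer, while server-streaming and bidirectional calls iterate the handler's return value. A condensed sketch of that mapping, with hypothetical strings standing in for the constants in dubbo.common.constants (a bare membership test is equivalent to the factory's "True if ... else False" form):

# Hypothetical stand-ins for the call-type constants; the real values live in dubbo.common.constants.
UNARY, CLIENT_STREAM, SERVER_STREAM, BI_STREAM = (
    "unary", "client-stream", "server-stream", "bi-stream",
)

def streaming_flags(call_type: str):
    """Return (client_stream, server_stream) for a given call type."""
    client_stream = call_type in (CLIENT_STREAM, BI_STREAM)
    server_stream = call_type in (SERVER_STREAM, BI_STREAM)
    return client_stream, server_stream

assert streaming_flags(UNARY) == (False, False)
assert streaming_flags(BI_STREAM) == (True, True)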
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, List, Optional, Tuple - -from dubbo.compressor.compression import Compression -from dubbo.protocol.triple.tri_codec import TriEncoder -from dubbo.protocol.triple.tri_results import AbstractTriResult -from dubbo.protocol.triple.tri_status import TriRpcStatus -from dubbo.remoting.aio.http2.headers import Http2Headers -from dubbo.remoting.aio.http2.registries import Http2ErrorCode -from dubbo.remoting.aio.http2.stream import Http2Stream -from dubbo.serialization import Serialization - - -class ClientCall: - """ - The client call. - """ - - def __init__(self, listener: "ClientCall.Listener"): - self._listener = listener - self._stream: Optional[Http2Stream] = None - - def bind_stream(self, stream: Http2Stream) -> None: - """ - Bind stream - """ - self._stream = stream - - def send_headers(self, headers: Http2Headers) -> None: - """ - Send headers. - Args: - headers: The headers. - """ - raise NotImplementedError("send_headers() is not implemented.") - - def send_message(self, message: Any, last: bool = False) -> None: - """ - Send message. - Args: - message: The message. - last: Whether this is the last message. - """ - raise NotImplementedError("send_message() is not implemented.") - - def send_reset(self, error_code: Http2ErrorCode) -> None: - """ - Send a reset. - Args: - error_code: The error code. - """ - raise NotImplementedError("send_reset() is not implemented.") - - class Listener: - """ - The listener of the client call. - """ - - def on_message(self, message: Any) -> None: - """ - Called when a message is received. - """ - raise NotImplementedError("on_message() is not implemented.") - - def on_close( - self, rpc_status: TriRpcStatus, trailers: List[Tuple[str, str]] - ) -> None: - """ - Called when the stream is closed. - """ - raise NotImplementedError("on_close() is not implemented.") - - -class TriClientCall(ClientCall): - """ - The triple client call. - """ - - def __init__( - self, - result: AbstractTriResult, - serialization: Serialization, - compression: Optional[Compression] = None, - ): - super().__init__(TriClientCall.Listener(result, serialization)) - self._serialization = serialization - self._tri_encoder = TriEncoder(compression) - - @property - def listener(self) -> "TriClientCall.Listener": - return self._listener - - def send_headers(self, headers: Http2Headers) -> None: - """ - Send headers. - """ - self._stream.send_headers(headers, end_stream=False) - - def send_message(self, message: Any, last: bool = False) -> None: - """ - Send a message. - """ - # Serialize the message - serialized_message = self._serialization.serialize(message) - - # Encode the message - encode_message = self._tri_encoder.encode(serialized_message) - self._stream.send_data(encode_message, end_stream=last) - - def send_reset(self, error_code: Http2ErrorCode) -> None: - """ - Send a reset. - """ - self._stream.send_reset(error_code) - - class Listener(ClientCall.Listener): - """ - The listener of the triple client call. 
- """ - - def __init__(self, result: AbstractTriResult, serialization: Serialization): - self._result = result - self._serialization = serialization - - def on_message(self, message: Any) -> None: - """ - Called when a message is received. - """ - # Deserialize the message - deserialized_message = self._serialization.deserialize(message) - self._result.set_value(deserialized_message) - - def on_close( - self, rpc_status: TriRpcStatus, trailers: List[Tuple[str, str]] - ) -> None: - """ - Called when the stream is closed. - """ - if rpc_status.cause: - self._result.set_exception(rpc_status.cause) - # Notify the result that the stream is complete - self._result.set_value(self._result.END_SIGNAL) diff --git a/dubbo/protocol/triple/client/stream_listener.py b/dubbo/protocol/triple/client/stream_listener.py deleted file mode 100644 index f757afb..0000000 --- a/dubbo/protocol/triple/client/stream_listener.py +++ /dev/null @@ -1,108 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Optional - -from dubbo.compressor.compression import Compression -from dubbo.logger.logger_factory import loggerFactory -from dubbo.protocol.triple.client.calls import ClientCall -from dubbo.protocol.triple.tri_codec import TriDecoder -from dubbo.protocol.triple.tri_constants import TripleHeaderName, TripleHeaderValue -from dubbo.protocol.triple.tri_status import TriRpcCode, TriRpcStatus -from dubbo.remoting.aio.http2.headers import Http2Headers -from dubbo.remoting.aio.http2.registries import Http2ErrorCode -from dubbo.remoting.aio.http2.stream import StreamListener - -logger = loggerFactory.get_logger(__name__) - - -class _TriDecoderListener(TriDecoder.Listener): - """ - Triple decoder listener. - """ - - def __init__(self, listener: ClientCall.Listener): - self._listener = listener - self._rpc_status = None - self._trailers = None - - def add_rpc_status(self, status: TriRpcStatus): - self._rpc_status = status - - def add_trailers(self, trailers: list): - self._trailers = trailers - - def on_message(self, message: Any) -> None: - self._listener.on_message(message) - - def close(self): - self._listener.on_close(self._rpc_status, self._trailers) - - -class TriClientStreamListener(StreamListener): - """ - Stream listener for triple client. 
- """ - - def __init__( - self, listener: ClientCall.Listener, compression: Optional[Compression] = None - ): - super().__init__() - self._tri_decoder_listener = _TriDecoderListener(listener) - self._tri_decoder = TriDecoder(self._tri_decoder_listener, compression) - - def on_headers(self, headers: Http2Headers, end_stream: bool) -> None: - # validate headers - validated = True - if headers.status != "200": - # Illegal response code - validated = False - logger.error(f"Invalid response code: {headers.status}") - if content_type := headers.get(TripleHeaderName.CONTENT_TYPE.value): - # Invalid content type - if not content_type.startswith(TripleHeaderValue.APPLICATION_GRPC.value): - validated = False - logger.error( - f"Invalid content type: {headers.get(TripleHeaderName.CONTENT_TYPE.value)}" - ) - else: - # Missing content type - validated = False - logger.error("Missing content type") - - if not validated: - # TODO channel by local - pass - - def on_data(self, data: bytes, end_stream: bool) -> None: - # Decode the data - self._tri_decoder.decode(data) - if end_stream: - self._tri_decoder.close() - - def on_trailers(self, headers: Http2Headers) -> None: - tri_status = TriRpcStatus( - TriRpcCode.from_code(int(headers.get(TripleHeaderName.GRPC_STATUS.value))), - description=headers.get(TripleHeaderName.GRPC_MESSAGE.value), - ) - trailers = headers.to_list() - - self._tri_decoder_listener.add_rpc_status(tri_status) - self._tri_decoder_listener.add_trailers(trailers) - - self._tri_decoder.close() - - def on_reset(self, error_code: Http2ErrorCode) -> None: - pass diff --git a/dubbo/protocol/triple/tri_codec.py b/dubbo/protocol/triple/coders.py similarity index 56% rename from dubbo/protocol/triple/tri_codec.py rename to dubbo/protocol/triple/coders.py index 7cd227b..994bd6f 100644 --- a/dubbo/protocol/triple/tri_codec.py +++ b/dubbo/protocol/triple/coders.py @@ -13,13 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import abc import struct from typing import Optional -from dubbo.compressor.compression import Compression +from dubbo.compression import Compressor, Decompressor +from dubbo.protocol.triple.exceptions import RpcError """ - gRPC Message Format Diagram + gRPC Message Format Diagram (HTTP/2 Data Frame): +----------------------+-------------------------+------------------+ | HTTP Header | gRPC Header | Business Data | +----------------------+-------------------------+------------------+ @@ -28,6 +31,8 @@ +----------------------+-------------------------+------------------+ """ +__all__ = ["TriEncoder", "TriDecoder"] + HEADER: str = "HEADER" PAYLOAD: str = "PAYLOAD" @@ -35,42 +40,75 @@ HEADER_LENGTH: int = 5 COMPRESSED_FLAG_MASK: int = 1 RESERVED_MASK = 0xFE +DEFAULT_MAX_MESSAGE_SIZE: int = 4194304 # 4MB class TriEncoder: """ This class is responsible for encoding the gRPC message format, which is composed of a header and payload. - - Args: - compression (Optional[Compression]): The Compression to use for compressing or decompressing the payload. """ - HEADER_LENGTH: int = 5 - COMPRESSED_FLAG_MASK: int = 1 + __slots__ = ["_compressor"] + + def __init__(self, compressor: Optional[Compressor]): + """ + Initialize the encoder. + :param compressor: The compression to use for compressing the payload. 
+ :type compressor: Optional[Compressor] + """ + self._compressor = compressor + + @property + def compressor(self) -> Optional[Compressor]: + """ + Get the compressor. + :return: The compressor. + :rtype: Optional[Compressor] + """ + return self._compressor - def __init__(self, compression: Optional[Compression]): - self._compression = compression + @compressor.setter + def compressor(self, value: Compressor) -> None: + """ + Set the compressor. + :param value: The compressor. + :type value: Compressor + """ + self._compressor = value - def encode(self, message: bytes) -> bytes: + def encode(self, message: bytes, compress_flag: int) -> bytes: """ Encode the message into the gRPC message format. - Args: - message (bytes): The message to encode. - Returns: - bytes: The encoded message in gRPC format. + :param message: The message to encode. + :type message: bytes + :param compress_flag: The compress flag. 0 for no compression, 1 for compression. + :type compress_flag: int + :return: The encoded message. + :rtype: bytes """ - compressed_flag = COMPRESSED_FLAG_MASK if self._compression else 0 - if self._compression: - # Compress the payload - message = self._compression.compress(message) - message_length = len(message) - if message_length > 0xFFFFFFFF: - raise ValueError("Message too large to encode") + # check compress_flag + if compress_flag not in [0, 1]: + raise RpcError(f"compress_flag must be 0 or 1, but got {compress_flag}") + + # check message size + if len(message) > DEFAULT_MAX_MESSAGE_SIZE: + raise RpcError( + f"Message too large. Allowed maximum size is 4194304 bytes, but got {len(message)} bytes." + ) - # Create the header - header = struct.pack(">BI", compressed_flag, message_length) + # check compress_flag and compress the payload + if compress_flag == 1: + if not self._compressor: + raise RpcError("compression is required when compress_flag is 1") + message = self._compressor.compress(message) + + # Create the gRPC header + # >: big-endian + # B: unsigned char(1 byte) -> compressed_flag + # I: unsigned int(4 bytes) -> message_length + header = struct.pack(">BI", compress_flag, len(message)) return header + message @@ -78,21 +116,37 @@ def encode(self, message: bytes) -> bytes: class TriDecoder: """ This class is responsible for decoding the gRPC message format, which is composed of a header and payload. - - Args: - listener (TriDecoder.Listener): The listener to deliver the decoded payload to. - compression (Optional[Compression]): The Compression to use for compressing or decompressing the payload. """ + __slots__ = [ + "_accumulate", + "_listener", + "_decompressor", + "_state", + "_required_length", + "_decoding", + "_compressed", + "_closing", + "_closed", + ] + def __init__( self, listener: "TriDecoder.Listener", - compression: Optional[Compression], + decompressor: Optional[Decompressor], ): + """ + Initialize the decoder. + :param decompressor: The decompressor to use for decompressing the payload. + :type decompressor: Optional[Decompressor] + :param listener: The listener to deliver the decoded payload to when a message is received. 
+ :type listener: TriDecoder.Listener + """ + + self._listener = listener # store data for decoding self._accumulate = bytearray() - self._listener = listener - self._compression = compression + self._decompressor = decompressor self._state = HEADER self._required_length = HEADER_LENGTH @@ -109,6 +163,8 @@ def __init__( def decode(self, data: bytes) -> None: """ Process the incoming bytes, decoding the gRPC message and delivering the payload to the listener. + :param data: The data to decode. + :type data: bytes """ self._accumulate.extend(data) self._do_decode() @@ -145,6 +201,8 @@ def _do_decode(self) -> None: def _has_enough_bytes(self) -> bool: """ Check if the accumulated bytes are enough to process the header or payload + :return: True if there are enough bytes, False otherwise. + :rtype: bool """ return len(self._accumulate) >= self._required_length @@ -154,15 +212,16 @@ def _process_header(self) -> None: """ header_bytes = self._accumulate[: self._required_length] self._accumulate = self._accumulate[self._required_length :] + # Parse the header - compressed_flag = header_bytes[0] + compressed_flag = int(header_bytes[0]) if (compressed_flag & RESERVED_MASK) != 0: - raise ValueError("gRPC frame header malformed: reserved bits not zero") - - self._compressed = bool(compressed_flag & COMPRESSED_FLAG_MASK) - self._required_length = int.from_bytes(header_bytes[1:], byteorder="big") - # Continue to process the payload - self._state = PAYLOAD + raise RpcError("gRPC frame header malformed: reserved bits not zero") + else: + self._compressed = bool(compressed_flag & COMPRESSED_FLAG_MASK) + self._required_length = int.from_bytes(header_bytes[1:], byteorder="big") + # Continue to process the payload + self._state = PAYLOAD def _process_payload(self) -> None: """ @@ -173,7 +232,7 @@ def _process_payload(self) -> None: if self._compressed: # Decompress the payload - payload_bytes = self._compression.decompress(payload_bytes) + payload_bytes = self._decompressor.decompress(payload_bytes) self._listener.on_message(bytes(payload_bytes)) @@ -181,15 +240,20 @@ def _process_payload(self) -> None: self._required_length = HEADER_LENGTH self._state = HEADER - class Listener: + class Listener(abc.ABC): + + @abc.abstractmethod def on_message(self, message: bytes): """ Called when a message is received. + :param message: The message received. + :type message: bytes """ - raise NotImplementedError("Listener.on_message() not implemented") + raise NotImplementedError() + @abc.abstractmethod def close(self): """ Called when the listener is closed. """ - raise NotImplementedError("Listener.close() not implemented") + raise NotImplementedError() diff --git a/dubbo/protocol/triple/tri_status.py b/dubbo/protocol/triple/constants.py similarity index 75% rename from dubbo/protocol/triple/tri_status.py rename to dubbo/protocol/triple/constants.py index c767c24..a51244e 100644 --- a/dubbo/protocol/triple/tri_status.py +++ b/dubbo/protocol/triple/constants.py @@ -13,11 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import enum -from typing import Optional -class TriRpcCode(enum.Enum): +class GRpcCode(enum.Enum): """ RPC status codes. 
See https://github.com/grpc/grpc/blob/master/doc/statuscodes.md @@ -75,7 +75,7 @@ class TriRpcCode(enum.Enum): UNAUTHENTICATED = 16 @classmethod - def from_code(cls, code: int) -> "TriRpcCode": + def from_code(cls, code: int) -> "GRpcCode": """ Get the RPC status code from the given code. Args: @@ -87,24 +87,36 @@ def from_code(cls, code: int) -> "TriRpcCode": return cls.UNKNOWN -class TriRpcStatus: +class TripleHeaderName(enum.Enum): + """ + Header names used in triple protocol. + """ + + CONTENT_TYPE = "content-type" + + TE = "te" + GRPC_STATUS = "grpc-status" + GRPC_MESSAGE = "grpc-message" + GRPC_STATUS_DETAILS_BIN = "grpc-status-details-bin" + GRPC_TIMEOUT = "grpc-timeout" + GRPC_ENCODING = "grpc-encoding" + GRPC_ACCEPT_ENCODING = "grpc-accept-encoding" + + SERVICE_VERSION = "tri-service-version" + SERVICE_GROUP = "tri-service-group" + + CONSUMER_APP_NAME = "tri-consumer-appname" + + +class TripleHeaderValue(enum.Enum): """ - RPC status. - Args: - code: RPC status code. - cause: Optional exception that caused the RPC status. - description: Optional description of the RPC status. + Header values used in triple protocol. """ - def __init__( - self, - code: TriRpcCode, - cause: Optional[Exception] = None, - description: Optional[str] = None, - ): - self.code = code - self.cause = cause - self.description = description - - def __repr__(self): - return f"TriRpcStatus(code={self.code}, cause={self.cause}, description={self.description})" + TRAILERS = "trailers" + HTTP = "http" + HTTPS = "https" + APPLICATION_GRPC_PROTO = "application/grpc+proto" + APPLICATION_GRPC = "application/grpc" + + TEXT_PLAIN_UTF8 = "text/plain; encoding=utf-8" diff --git a/dubbo/protocol/triple/tri_constants.py b/dubbo/protocol/triple/exceptions.py similarity index 57% rename from dubbo/protocol/triple/tri_constants.py rename to dubbo/protocol/triple/exceptions.py index 34e3120..6dbfcb9 100644 --- a/dubbo/protocol/triple/tri_constants.py +++ b/dubbo/protocol/triple/exceptions.py @@ -13,32 +13,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum +__all__ = ["RpcError", "StatusRpcError"] -class TripleHeaderName(enum.Enum): + +class RpcError(Exception): """ - Header names used in triple protocol. + The RPC exception. """ - CONTENT_TYPE = "content-type" + def __init__(self, message: str): + self.message = f"RPC Invocation failed: {message}" + super().__init__(self.message) - TE = "te" - GRPC_STATUS = "grpc-status" - GRPC_MESSAGE = "grpc-message" - GRPC_STATUS_DETAILS_BIN = "grpc-status-details-bin" - GRPC_TIMEOUT = "grpc-timeout" - GRPC_ENCODING = "grpc-encoding" - GRPC_ACCEPT_ENCODING = "grpc-accept-encoding" + def __str__(self): + return self.message -class TripleHeaderValue(enum.Enum): +class StatusRpcError(Exception): """ - Header values used in triple protocol. + The status RPC exception. 
""" - TRAILERS = "trailers" - HTTP = "http" - HTTPS = "https" - APPLICATION_GRPC_PROTO = "application/grpc+proto" - APPLICATION_GRPC = "application/grpc" + def __init__(self, status): + self.status = status + self.message = f"RPC Invocation failed: {status.code} {status.description}" + super().__init__(status, self.message) + + def __str__(self): + return self.message diff --git a/dubbo/protocol/triple/invoker.py b/dubbo/protocol/triple/invoker.py new file mode 100644 index 0000000..d835036 --- /dev/null +++ b/dubbo/protocol/triple/invoker.py @@ -0,0 +1,215 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.common import constants as common_constants +from dubbo.common.url import URL +from dubbo.compression import Compressor, Identity +from dubbo.extension import ExtensionError, extensionLoader +from dubbo.logger import loggerFactory +from dubbo.protocol import Invoker, Result +from dubbo.protocol.invocation import Invocation, RpcInvocation +from dubbo.protocol.triple.call import TripleClientCall +from dubbo.protocol.triple.call.client_call import DefaultClientCallListener +from dubbo.protocol.triple.constants import TripleHeaderName, TripleHeaderValue +from dubbo.protocol.triple.metadata import RequestMetadata +from dubbo.protocol.triple.results import TriResult +from dubbo.remoting import Client +from dubbo.remoting.aio.http2.stream_handler import StreamClientMultiplexHandler +from dubbo.serialization import ( + CustomDeserializer, + CustomSerializer, + DirectDeserializer, + DirectSerializer, +) + +__all__ = ["TripleInvoker"] + +_LOGGER = loggerFactory.get_logger(__name__) + + +class TripleInvoker(Invoker): + """ + Triple invoker. 
+ """ + + __slots__ = ["_url", "_client", "_stream_multiplexer", "_compression", "_destroyed"] + + def __init__( + self, url: URL, client: Client, stream_multiplexer: StreamClientMultiplexHandler + ): + self._url = url + self._client = client + self._stream_multiplexer = stream_multiplexer + + self._destroyed = False + + def invoke(self, invocation: RpcInvocation) -> Result: + call_type = invocation.get_attribute(common_constants.CALL_KEY) + result = TriResult(call_type) + + if not self._client.is_connected(): + # Reconnect the client + self._client.reconnect() + + # get serializer + serializer = DirectSerializer() + serializing_function = invocation.get_attribute(common_constants.SERIALIZER_KEY) + if serializing_function: + serializer = CustomSerializer(serializing_function) + + # get deserializer + deserializer = DirectDeserializer() + deserializing_function = invocation.get_attribute( + common_constants.DESERIALIZER_KEY + ) + if deserializing_function: + deserializer = CustomDeserializer(deserializing_function) + + # Create a new TriClientCall + tri_client_call = TripleClientCall( + self._stream_multiplexer, + DefaultClientCallListener(result), + serializer, + deserializer, + ) + + # start the call + try: + metadata = self._create_metadata(invocation) + tri_client_call.start(metadata) + except ExtensionError as e: + result.set_exception(e) + return result + + # invoke + if call_type in ( + common_constants.UNARY_CALL_VALUE, + common_constants.SERVER_STREAM_CALL_VALUE, + ): + self._invoke_unary(tri_client_call, invocation) + elif call_type in ( + common_constants.CLIENT_STREAM_CALL_VALUE, + common_constants.BI_STREAM_CALL_VALUE, + ): + self._invoke_stream(tri_client_call, invocation) + + return result + + def _invoke_unary(self, call: TripleClientCall, invocation: Invocation) -> None: + """ + Invoke a unary call. + :param call: The call to invoke. + :type call: TripleClientCall + :param invocation: The invocation to invoke. + :type invocation: Invocation + """ + try: + argument = invocation.get_argument() + if callable(argument): + argument = argument() + except Exception as e: + _LOGGER.exception(f"Invoke failed: {str(e)}", e) + call.cancel_by_local(e) + return + + # send the message + call.send_message(argument, last=True) + + def _invoke_stream(self, call: TripleClientCall, invocation: Invocation) -> None: + """ + Invoke a stream call. + :param call: The call to invoke. + :type call: TripleClientCall + :param invocation: The invocation to invoke. + :type invocation: Invocation + """ + try: + # get the argument + argument = invocation.get_argument() + iterator = argument() if callable(argument) else argument + + # send the messages + BEGIN_SIGNAL = object() + next_message = BEGIN_SIGNAL + for message in iterator: + if next_message is not BEGIN_SIGNAL: + call.send_message(next_message, last=False) + next_message = message + next_message = next_message if next_message is not BEGIN_SIGNAL else None + call.send_message(next_message, last=True) + except Exception as e: + _LOGGER.exception(f"Invoke failed: {str(e)}", e) + call.cancel_by_local(e) + + def _create_metadata(self, invocation: Invocation) -> RequestMetadata: + """ + Create the metadata. + :param invocation: The invocation. + :type invocation: Invocation + :return: The metadata. + :rtype: RequestMetadata + :raise ExtensionError: If the compressor is not supported. 
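_invoke_stream above holds each outgoing element back by one step so that it can pass last=True on the final send_message call, and degrades to a single None message when the iterator is empty. A standalone sketch of that look-ahead pattern, under an illustrative helper name:

from typing import Any, Iterable, Iterator, Tuple

def with_last_flag(messages: Iterable[Any]) -> Iterator[Tuple[Any, bool]]:
    """Yield (message, is_last) pairs, marking the final element of the iterable."""
    sentinel = object()  # distinguishes "nothing buffered yet" from any real message
    previous = sentinel
    for current in messages:
        if previous is not sentinel:
            yield previous, False
        previous = current
    # An empty input collapses to a single (None, True) pair, mirroring the invoker.
    yield (previous if previous is not sentinel else None), True

assert list(with_last_flag([1, 2, 3])) == [(1, False), (2, False), (3, True)]
assert list(with_last_flag([])) == [(None, True)]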
+ """ + metadata = RequestMetadata() + # set service and method + metadata.service = invocation.get_service_name() + metadata.method = invocation.get_method_name() + + # get scheme + metadata.scheme = ( + TripleHeaderValue.HTTPS.value + if self._url.parameters.get(common_constants.SSL_ENABLED_KEY, False) + else TripleHeaderValue.HTTP.value + ) + + # get compressor + compression = self._url.parameters.get( + common_constants.COMPRESSION_KEY, Identity.get_message_encoding() + ) + if metadata.compressor.get_message_encoding() != compression: + try: + metadata.compressor = extensionLoader.get_extension( + Compressor, compression + )() + except ExtensionError as e: + _LOGGER.error(f"Unsupported compression: {compression}") + raise e + + # get address + metadata.address = self._url.location + + # TODO add more metadata + metadata.attachments[TripleHeaderName.TE.value] = ( + TripleHeaderValue.TRAILERS.value + ) + + return metadata + + def get_url(self) -> URL: + return self._url + + def is_available(self) -> bool: + return self._client.is_connected() + + @property + def destroyed(self) -> bool: + return self._destroyed + + def destroy(self) -> None: + self._client.close() + self._client = None + self._stream_multiplexer = None + self._url = None diff --git a/dubbo/protocol/triple/metadata.py b/dubbo/protocol/triple/metadata.py new file mode 100644 index 0000000..974277b --- /dev/null +++ b/dubbo/protocol/triple/metadata.py @@ -0,0 +1,95 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +from dubbo.compression import Compressor, Identity +from dubbo.protocol.triple.constants import TripleHeaderName, TripleHeaderValue +from dubbo.remoting.aio.http2.headers import Http2Headers, HttpMethod + + +class RequestMetadata: + """ + The request metadata. + """ + + def __init__(self): + self.scheme: Optional[str] = None + self.application: Optional[str] = None + self.service: Optional[str] = None + self.version: Optional[str] = None + self.group: Optional[str] = None + self.address: Optional[str] = None + self.acceptEncoding: Optional[str] = None + self.timeout: Optional[str] = None + self.compressor: Compressor = Identity() + self.method: Optional[str] = None + self.attachments: Dict[str, Any] = {} + + def to_headers(self) -> Http2Headers: + """ + Convert to HTTP/2 headers. + :return: The HTTP/2 headers. 
+ :rtype: Http2Headers + """ + headers = Http2Headers() + headers.scheme = self.scheme + headers.authority = self.address + headers.method = HttpMethod.POST.value + headers.path = f"/{self.service}/{self.method}" + headers.add( + TripleHeaderName.CONTENT_TYPE.value, + TripleHeaderValue.APPLICATION_GRPC_PROTO.value, + ) + + if self.version != "1.0.0": + set_if_not_none( + headers, TripleHeaderName.SERVICE_VERSION.value, self.version + ) + + set_if_not_none(headers, TripleHeaderName.GRPC_TIMEOUT.value, self.timeout) + set_if_not_none(headers, TripleHeaderName.SERVICE_GROUP.value, self.group) + set_if_not_none( + headers, TripleHeaderName.CONSUMER_APP_NAME.value, self.application + ) + set_if_not_none( + headers, TripleHeaderName.GRPC_ENCODING.value, self.acceptEncoding + ) + + if self.compressor.get_message_encoding() != Identity.get_message_encoding(): + set_if_not_none( + headers, + TripleHeaderName.GRPC_ENCODING.value, + self.compressor.get_message_encoding(), + ) + + [headers.add(k, str(v)) for k, v in self.attachments.items()] + + return headers + + +def set_if_not_none(headers: Http2Headers, key: str, value: Optional[str]) -> None: + """ + Set the header if the value is not None. + :param headers: The headers. + :type headers: Http2Headers + :param key: The key. + :type key: str + :param value: The value. + :type value: Optional[str] + """ + if value: + headers.add(key, str(value)) diff --git a/dubbo/protocol/triple/protocol.py b/dubbo/protocol/triple/protocol.py new file mode 100644 index 0000000..9347fc8 --- /dev/null +++ b/dubbo/protocol/triple/protocol.py @@ -0,0 +1,106 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, Optional + +from dubbo.common import constants as common_constants +from dubbo.common.url import URL +from dubbo.extension import extensionLoader +from dubbo.logger import loggerFactory +from dubbo.protocol import Invoker, Protocol +from dubbo.protocol.triple.invoker import TripleInvoker +from dubbo.protocol.triple.stream.server_stream import ServerTransportListener +from dubbo.proxy.handlers import RpcServiceHandler +from dubbo.remoting import Server, Transporter +from dubbo.remoting.aio import constants as aio_constants +from dubbo.remoting.aio.http2.protocol import Http2Protocol +from dubbo.remoting.aio.http2.stream_handler import ( + StreamClientMultiplexHandler, + StreamServerMultiplexHandler, +) + +_LOGGER = loggerFactory.get_logger(__name__) + + +class TripleProtocol(Protocol): + """ + Triple protocol. 
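RequestMetadata.to_headers above assembles the HTTP/2 pseudo-headers plus the fixed gRPC headers of a triple request (POST, /<service>/<method>, content-type application/grpc+proto, te: trailers). Roughly the same wire-level shape, sketched as a plain list of name/value pairs instead of the patch's Http2Headers object:

def grpc_request_headers(authority: str, service: str, method: str, scheme: str = "http"):
    """Pseudo-headers and fixed headers of a gRPC-over-HTTP/2 request."""
    return [
        (":method", "POST"),
        (":scheme", scheme),
        (":authority", authority),
        (":path", f"/{service}/{method}"),
        ("content-type", "application/grpc+proto"),
        ("te", "trailers"),
    ]

headers = grpc_request_headers("127.0.0.1:50051", "GreeterService", "sayHello")  # illustrative values
assert (":path", "/GreeterService/sayHello") in headers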
+ """ + + __slots__ = ["_url", "_transporter", "_invokers"] + + def __init__(self, url: URL): + self._url = url + self._transporter: Transporter = extensionLoader.get_extension( + Transporter, + self._url.parameters.get( + common_constants.TRANSPORTER_KEY, + common_constants.TRANSPORTER_DEFAULT_VALUE, + ), + )() + self._invokers = [] + self._server: Optional[Server] = None + + self._path_resolver: Dict[str, RpcServiceHandler] = {} + + def export(self, url: URL): + """ + Export a service. + """ + if self._server is not None: + return + + service_handler: RpcServiceHandler = url.attributes[ + common_constants.SERVICE_HANDLER_KEY + ] + + self._path_resolver[service_handler.service_name] = service_handler + + def listener_factory(_path_resolver): + return ServerTransportListener(_path_resolver) + + fn = functools.partial(listener_factory, self._path_resolver) + + # Create a stream handler + executor = ThreadPoolExecutor(thread_name_prefix="dubbo-tri-") + stream_multiplexer = StreamServerMultiplexHandler(fn, executor) + # set stream handler and protocol + url.attributes[aio_constants.STREAM_HANDLER_KEY] = stream_multiplexer + url.attributes[common_constants.PROTOCOL_KEY] = Http2Protocol + + # Create a server + self._server = self._transporter.bind(url) + + def refer(self, url: URL) -> Invoker: + """ + Refer a remote service. + Args: + url (URL): The URL of the remote service. + """ + executor = ThreadPoolExecutor(thread_name_prefix="dubbo-tri-") + # Create a stream handler + stream_multiplexer = StreamClientMultiplexHandler(executor) + # set stream handler and protocol + url.attributes[aio_constants.STREAM_HANDLER_KEY] = stream_multiplexer + url.attributes[common_constants.PROTOCOL_KEY] = Http2Protocol + + # Create a client + client = self._transporter.connect(url) + invoker = TripleInvoker(url, client, stream_multiplexer) + self._invokers.append(invoker) + return invoker diff --git a/dubbo/protocol/triple/tri_results.py b/dubbo/protocol/triple/results.py similarity index 51% rename from dubbo/protocol/triple/tri_results.py rename to dubbo/protocol/triple/results.py index 62d4a27..c91a22b 100644 --- a/dubbo/protocol/triple/tri_results.py +++ b/dubbo/protocol/triple/results.py @@ -13,70 +13,63 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import queue -from typing import Any, Dict, Optional -from dubbo.constants.common_constants import CALL_CLIENT_STREAM, CALL_UNARY -from dubbo.protocol.result import Result +from typing import Any +from dubbo.common import constants as common_constants +from dubbo.common.deliverers import MultiMessageDeliverer, SingleMessageDeliverer +from dubbo.protocol import Result -class AbstractTriResult(Result): - """ - The abstract result. - """ - - END_SIGNAL = object() - - def __init__(self, call_type: str): - self.call_type = call_type - self._exception: Optional[Exception] = None - self._attachments: Dict[str, Any] = {} - - def set_exception(self, exception: Exception) -> None: - self._exception = exception - - def exception(self) -> Exception: - return self._exception - - def add_attachment(self, key: str, value: Any) -> None: - self._attachments[key] = value - def get_attachment(self, key: str) -> Any: - return self._attachments.get(key) - - -class TriResult(AbstractTriResult): +class TriResult(Result): """ The triple result. 
""" def __init__(self, call_type: str): - super().__init__(call_type) - self._values = queue.Queue() + self._streamed = True + if call_type in [ + common_constants.UNARY_CALL_VALUE, + common_constants.CLIENT_STREAM_CALL_VALUE, + ]: + self._streamed = False + + self._deliverer = ( + MultiMessageDeliverer() if self._streamed else SingleMessageDeliverer() + ) + + self._exception = None def set_value(self, value: Any) -> None: """ Set the value. """ - self._values.put(value) + self._deliverer.add(value) + + def complete_value(self) -> None: + """ + Complete the value. + """ + self._deliverer.complete() def value(self) -> Any: """ Get the value. """ - if self.call_type in [CALL_UNARY, CALL_CLIENT_STREAM]: - return self._get_single_value() + if self._streamed: + return self._deliverer else: - return self._iterating_values() + return self._deliverer.get() - def _get_single_value(self) -> Any: + def set_exception(self, exception: Exception) -> None: """ - Get the single value. + Set the exception. """ - return value if (value := self._values.get()) is not self.END_SIGNAL else None + self._exception = exception + self._deliverer.cancel(exception) - def _iterating_values(self) -> Any: + def exception(self) -> Exception: """ - Iterate the values. + Get the exception. """ - return iter(lambda: self._values.get(), self.END_SIGNAL) + return self._exception diff --git a/dubbo/protocol/triple/status.py b/dubbo/protocol/triple/status.py new file mode 100644 index 0000000..6e31790 --- /dev/null +++ b/dubbo/protocol/triple/status.py @@ -0,0 +1,152 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +from dubbo.protocol.triple.constants import GRpcCode +from dubbo.protocol.triple.exceptions import StatusRpcError +from dubbo.remoting.aio.http2.registries import HttpStatus + + +class TriRpcStatus: + """ + RPC status. + """ + + __slots__ = ["_code", "_cause", "_description"] + + def __init__( + self, + code: GRpcCode, + cause: Optional[Exception] = None, + description: Optional[str] = None, + ): + """ + Initialize the RPC status. + :param code: The RPC status code. + :type code: TriRpcCode + :param description: The description. + :type description: Optional[str] + :param cause: The exception cause. + :type cause: Optional[Exception] + """ + if isinstance(code, int): + code = GRpcCode.from_code(code) + self._code = code + self._description = description + self._cause = cause + + @property + def code(self) -> GRpcCode: + return self._code + + @property + def description(self) -> Optional[str]: + return self._description + + @property + def cause(self) -> Optional[Exception]: + return self._cause + + def with_description(self, description: str) -> "TriRpcStatus": + """ + Set the description. + :param description: The description. 
+ :type description: str + :return: The RPC status. + :rtype: TriRpcStatus + """ + self._description = description + return self + + def with_cause(self, cause: Exception) -> "TriRpcStatus": + """ + Set the cause. + :param cause: The cause. + :type cause: Exception + :return: The RPC status. + :rtype: TriRpcStatus + """ + self._cause = cause + return self + + def append_description(self, description: str) -> None: + """ + Append the description. + :param description: The description to append. + :type description: str + """ + if self._description: + self._description += f"\n{description}" + else: + self._description = description + + def as_exception(self) -> Exception: + """ + Convert the RPC status to an exception. + :return: The exception. + :rtype: Exception + """ + return StatusRpcError(self) + + @staticmethod + def limit_desc(description: str, limit: int = 1024) -> str: + """ + Limit the description length. + :param description: The description. + :type description: str + :param limit: The limit.(default: 1024) + :type limit: int + :return: The limited description. + :rtype: str + """ + if description and len(description) > limit: + return f"{description[:limit]}..." + return description + + @classmethod + def from_rpc_code(cls, code: Union[int, GRpcCode]): + if isinstance(code, int): + code = GRpcCode.from_code(code) + return cls(code) + + @classmethod + def from_http_code(cls, code: Union[int, HttpStatus]): + http_status = HttpStatus.from_code(code) if isinstance(code, int) else code + rpc_code = GRpcCode.UNKNOWN + if HttpStatus.is_1xx(http_status) or http_status in [ + HttpStatus.BAD_REQUEST, + HttpStatus.REQUEST_HEADER_FIELDS_TOO_LARGE, + ]: + rpc_code = GRpcCode.INTERNAL + elif http_status == HttpStatus.UNAUTHORIZED: + rpc_code = GRpcCode.UNAUTHENTICATED + elif http_status == HttpStatus.FORBIDDEN: + rpc_code = GRpcCode.PERMISSION_DENIED + elif http_status == HttpStatus.NOT_FOUND: + rpc_code = GRpcCode.NOT_FOUND + elif http_status in [ + HttpStatus.BAD_GATEWAY, + HttpStatus.TOO_MANY_REQUESTS, + HttpStatus.SERVICE_UNAVAILABLE, + HttpStatus.GATEWAY_TIMEOUT, + ]: + rpc_code = GRpcCode.UNAVAILABLE + + return cls(rpc_code) + + def __repr__(self): + return f"TriRpcStatus(code={self._code}, cause={self._cause}, description={self._description})" diff --git a/dubbo/client/__init__.py b/dubbo/protocol/triple/stream/__init__.py similarity index 88% rename from dubbo/client/__init__.py rename to dubbo/protocol/triple/stream/__init__.py index bcba37a..5dc8c8f 100644 --- a/dubbo/client/__init__.py +++ b/dubbo/protocol/triple/stream/__init__.py @@ -13,3 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from ._interfaces import ClientStream, ServerStream + +__all__ = ["ClientStream", "ServerStream"] diff --git a/dubbo/protocol/triple/stream/_interfaces.py b/dubbo/protocol/triple/stream/_interfaces.py new file mode 100644 index 0000000..369fd07 --- /dev/null +++ b/dubbo/protocol/triple/stream/_interfaces.py @@ -0,0 +1,167 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any, Dict + +from dubbo.protocol.triple.status import TriRpcStatus +from dubbo.remoting.aio.http2.headers import Http2Headers + +__all__ = ["Stream", "ClientStream", "ServerStream"] + + +class Stream(abc.ABC): + """ + Stream is a bidirectional channel that manipulates the data flow between peers. + Inbound data from remote peer is acquired by Stream.Listener. + Outbound data to remote peer is sent directly by Stream + """ + + @abc.abstractmethod + def send_headers(self, headers: Http2Headers) -> None: + """ + Send headers to remote peer + :param headers: The headers to send + :type headers: Http2Headers + """ + raise NotImplementedError() + + @abc.abstractmethod + def cancel_by_local(self, status: TriRpcStatus) -> None: + """ + Cancel the stream by local + :param status: The status + :type status: TriRpcStatus + """ + raise NotImplementedError() + + class Listener(abc.ABC): + """ + Listener is a callback interface that receives events on the stream. + """ + + @abc.abstractmethod + def on_message(self, data: bytes) -> None: + """ + Called when data is received. + :param data: The data received + :type data: bytes + """ + raise NotImplementedError() + + @abc.abstractmethod + def on_cancel_by_remote(self, status: TriRpcStatus) -> None: + """ + Called when the stream is cancelled by remote + :param status: The status + :type status: TriRpcStatus + """ + raise NotImplementedError() + + +class ClientStream(Stream, abc.ABC): + """ + ClientStream is used to send request to server and receive response from server. + """ + + @abc.abstractmethod + def send_message(self, data: bytes, compress_flag: int, last: bool) -> None: + """ + Send message to remote peer + :param data: The message data + :type data: bytes + :param compress_flag: The compress flag (0: no compress, 1: compress) + :type compress_flag: int + :param last: Whether this is the last message + :type last: bool + """ + raise NotImplementedError() + + class Listener(Stream.Listener, abc.ABC): + """ + Listener is a callback interface that receives events on the stream. + """ + + @abc.abstractmethod + def on_complete( + self, status: TriRpcStatus, attachments: Dict[str, Any] + ) -> None: + """ + Called when the stream is completed. + :param status: The status + :type status: TriRpcStatus + :param attachments: The attachments + :type attachments: Dict[str,Any] + """ + raise NotImplementedError() + + +class ServerStream(Stream, abc.ABC): + """ + ServerStream is used to receive request from client and send response to client. + """ + + @abc.abstractmethod + def set_compression(self, compression: str) -> None: + """ + Set the compression. 
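+        The value is a message-encoding name (for example "gzip");
+        "identity" means no compression is applied.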
+ :param compression: The compression + :type compression: str + """ + raise NotImplementedError() + + @abc.abstractmethod + def send_message(self, data: bytes, compress_flag: bool) -> None: + """ + Send message to remote peer + :param data: The message data + :type data: bytes + :param compress_flag: The compress flag + :type compress_flag: bool + """ + raise NotImplementedError() + + @abc.abstractmethod + def complete(self, status: TriRpcStatus, attachments: Dict[str, Any]) -> None: + """ + Complete the stream + :param status: The status + :type status: TriRpcStatus + :param attachments: The attachments + :type attachments: Dict[str,Any] + """ + raise NotImplementedError() + + class Listener(Stream.Listener, abc.ABC): + """ + Listener is a callback interface that receives events on the stream. + """ + + @abc.abstractmethod + def on_headers(self, headers: Dict[str, Any]) -> None: + """ + Called when headers are received. + :param headers: The headers + :type headers: Http2Headers + """ + raise NotImplementedError() + + @abc.abstractmethod + def on_complete(self) -> None: + """ + Callback when no more data from client side + """ + raise NotImplementedError() diff --git a/dubbo/protocol/triple/stream/client_stream.py b/dubbo/protocol/triple/stream/client_stream.py new file mode 100644 index 0000000..3aef898 --- /dev/null +++ b/dubbo/protocol/triple/stream/client_stream.py @@ -0,0 +1,312 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from dubbo.compression import Compressor, Decompressor +from dubbo.compression.identities import Identity +from dubbo.extension import ExtensionError, extensionLoader +from dubbo.protocol.triple.coders import TriDecoder, TriEncoder +from dubbo.protocol.triple.constants import ( + GRpcCode, + TripleHeaderName, + TripleHeaderValue, +) +from dubbo.protocol.triple.status import TriRpcStatus +from dubbo.protocol.triple.stream import ClientStream +from dubbo.remoting.aio.http2.headers import Http2Headers +from dubbo.remoting.aio.http2.registries import Http2ErrorCode +from dubbo.remoting.aio.http2.stream import Http2Stream + +__all__ = ["TriClientStream"] + + +class TriClientStream(ClientStream): + """ + Triple client stream. + """ + + def __init__( + self, + listener: ClientStream.Listener, + compressor: Optional[Compressor], + ): + """ + Initialize the triple client stream. + :param listener: The listener. + :type listener: ClientStream.Listener + :param compressor: The compression. + """ + self._transport_listener = ClientTransportListener(listener) + self._encoder = TriEncoder(compressor) + + self._stream: Optional[Http2Stream] = None + + @property + def transport_listener(self) -> "ClientTransportListener": + """ + Get the transport listener. + :return: The transport listener. 
+ :rtype: ClientTransportListener + """ + return self._transport_listener + + def bind(self, stream: Http2Stream) -> None: + """ + Bind the stream. + :param stream: The stream to bind. + :type stream: Http2Stream + """ + self._stream = stream + + def send_headers(self, headers: Http2Headers) -> None: + """ + Send headers to remote peer. + :param headers: The headers to send. + :type headers: Http2Headers + """ + self._stream.send_headers(headers) + + def send_message(self, data: bytes, compress_flag: int, last: bool) -> None: + """ + Send message to remote peer. + :param data: The message data. + :type data: bytes + :param compress_flag: The compress flag (0: no compress, 1: compress). + :type compress_flag: int + :param last: Whether this is the last message. + :type last: bool + """ + # encode the data + encoded_data = self._encoder.encode(data, compress_flag) + self._stream.send_data(encoded_data, last) + + def cancel_by_local(self, status: TriRpcStatus) -> None: + """ + Cancel the stream by local + :param status: The status + :type status: TriRpcStatus + """ + self._stream.cancel_by_local(Http2ErrorCode.CANCEL) + self._transport_listener.rst = True + + +class ClientTransportListener(Http2Stream.Listener, TriDecoder.Listener): + """ + Client transport listener. + """ + + __slots__ = [ + "_listener", + "_decoder", + "_rpc_status", + "_headers_received", + "_rst", + ] + + def __init__(self, listener: ClientStream.Listener): + """ + Initialize the client transport listener. + :param listener: The listener. + """ + super().__init__() + self._listener = listener + + self._decoder: Optional[TriDecoder] = None + self._rpc_status: Optional[TriRpcStatus] = None + + self._headers_received = False + self._rst = False + + self._trailers: Http2Headers = Http2Headers() + + @property + def rst(self) -> bool: + """ + Whether the stream is rest. + :return: True if the stream is rest, otherwise False. + :rtype: bool + """ + return self._rst + + @rst.setter + def rst(self, value: bool) -> None: + """ + Set whether the stream is rest. + :param value: True if the stream is rest, otherwise False. + :type value: bool + """ + self._rst = value + + def on_headers(self, headers: Http2Headers, end_stream: bool) -> None: + if not end_stream: + # handle headers + self._on_headers_received(headers) + else: + # handle trailers + self._on_trailers_received(headers) + + if end_stream and not self._headers_received: + self._handle_transport_error(self._rpc_status) + + def on_data(self, data: bytes, end_stream: bool) -> None: + if self._rpc_status: + self._rpc_status.append_description(f"Data: {data.decode('utf-8')}") + if len(self._rpc_status.description) > 512 or end_stream: + self._handle_transport_error(self._rpc_status) + return + + # decode the data + self._decoder.decode(data) + + def cancel_by_remote(self, error_code: Http2ErrorCode) -> None: + self.rst = True + self._rpc_status = TriRpcStatus( + GRpcCode.CANCELLED, + description=f"Cancelled by remote peer, error code: {error_code}", + ) + self._listener.on_complete(self._rpc_status, self._trailers.to_dict()) + + def _on_headers_received(self, headers: Http2Headers) -> None: + """ + Handle the headers received. + :param headers: The headers. 
+ :type headers: Http2Headers + """ + self._headers_received = True + + # validate headers + self._validate_headers(headers) + if self._rpc_status: + return + + # get messageEncoding + decompressor: Optional[Decompressor] = None + message_encoding = headers.get( + TripleHeaderName.GRPC_ENCODING.value, Identity.get_message_encoding() + ) + if message_encoding != Identity.get_message_encoding(): + try: + # get decompressor by messageEncoding + decompressor = extensionLoader.get_extension( + Decompressor, message_encoding + )() + except ExtensionError: + # unsupported + self._rpc_status = TriRpcStatus( + GRpcCode.UNIMPLEMENTED, + description="Unsupported message encoding", + ) + return + + self._decoder = TriDecoder(self, decompressor) + + def _validate_headers(self, headers: Http2Headers) -> None: + """ + Validate the headers. + :param headers: The headers. + :type headers: Http2Headers + """ + status_code = int(headers.status) if headers.status else None + if status_code: + content_type = headers.get(TripleHeaderName.CONTENT_TYPE.value, "") + if not content_type.startswith(TripleHeaderValue.APPLICATION_GRPC.value): + self._rpc_status = TriRpcStatus.from_http_code( + status_code + ).with_description(f"Invalid content type: {content_type}") + + else: + self._rpc_status = TriRpcStatus( + GRpcCode.INTERNAL, description="Missing HTTP status code" + ) + + def _on_trailers_received(self, trailers: Http2Headers) -> None: + """ + Handle the trailers received. + :param trailers: The trailers. + :type trailers: Http2Headers + """ + if not self._rpc_status and not self._headers_received: + self._validate_headers(trailers) + + if self._rpc_status: + self._rpc_status.append_description(f"Trailers: {trailers}") + else: + self._rpc_status = self._get_status_from_trailers(trailers) + self._trailers = trailers + + if self._decoder: + self._decoder.close() + else: + self._listener.on_complete(self._rpc_status, trailers.to_dict()) + + def _get_status_from_trailers(self, trailers: Http2Headers) -> TriRpcStatus: + """ + Validate the trailers. + :param trailers: The trailers. + :type trailers: Http2Headers + :return: The RPC status. + :rtype: TriRpcStatus + """ + grpc_status_code = int(trailers.get(TripleHeaderName.GRPC_STATUS.value, "-1")) + if grpc_status_code != -1: + status = TriRpcStatus.from_rpc_code(grpc_status_code) + message = trailers.get(TripleHeaderName.GRPC_MESSAGE.value, "") + status.append_description(message) + return status + + # If the status code is not found , something is broken. Try to provide a rational error. + if self._headers_received: + return TriRpcStatus( + GRpcCode.UNKNOWN, description="Missing GRPC status in response" + ) + + # Try to get status from headers + status_code = int(trailers.status) if trailers.status else None + if status_code is not None: + status = TriRpcStatus.from_http_code(status_code) + else: + status = TriRpcStatus( + GRpcCode.INTERNAL, description="Missing HTTP status code" + ) + + status.append_description( + "Missing GRPC status, please infer the error from the HTTP status code" + ) + return status + + def _handle_transport_error(self, transport_error: TriRpcStatus) -> None: + """ + Handle the transport error. + :param transport_error: The transport error. 
+ :type transport_error: TriRpcStatus + """ + self._stream.cancel_by_local(Http2ErrorCode.NO_ERROR) + self.rst = True + self._listener.on_complete(transport_error, self._trailers.to_dict()) + + def on_message(self, message: bytes) -> None: + """ + Called when a message is received (TriDecoder.Listener callback). + :param message: The message received. + """ + self._listener.on_message(message) + + def close(self) -> None: + """ + Called when the stream is closed (TriDecoder.Listener callback). + """ + self._listener.on_complete(self._rpc_status, self._trailers.to_dict()) diff --git a/dubbo/protocol/triple/stream/server_stream.py b/dubbo/protocol/triple/stream/server_stream.py new file mode 100644 index 0000000..b642cfa --- /dev/null +++ b/dubbo/protocol/triple/stream/server_stream.py @@ -0,0 +1,325 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +from dubbo.compression import Decompressor +from dubbo.compression.identities import Identity +from dubbo.extension import ExtensionError, extensionLoader +from dubbo.logger import loggerFactory +from dubbo.logger.constants import Level +from dubbo.protocol.triple.call.server_call import TripleServerCall +from dubbo.protocol.triple.coders import TriDecoder, TriEncoder +from dubbo.protocol.triple.constants import ( + GRpcCode, + TripleHeaderName, + TripleHeaderValue, +) +from dubbo.protocol.triple.status import TriRpcStatus +from dubbo.protocol.triple.stream import ServerStream +from dubbo.proxy.handlers import RpcMethodHandler, RpcServiceHandler +from dubbo.remoting.aio.http2.headers import Http2Headers, HttpMethod +from dubbo.remoting.aio.http2.registries import Http2ErrorCode, HttpStatus +from dubbo.remoting.aio.http2.stream import Http2Stream + +__all__ = ["ServerTransportListener", "TripleServerStream"] + +_LOGGER = loggerFactory.get_logger(__name__) + + +class TripleServerStream(ServerStream): + + def __init__(self, stream: Http2Stream): + self._stream = stream + + self._tri_encoder = TriEncoder(Identity()) + + self._rst = False + self._headers_sent = False + self._trailers_sent = False + + @property + def rst(self) -> bool: + return self._rst + + @rst.setter + def rst(self, value: bool) -> None: + self._rst = value + + @property + def headers_sent(self) -> bool: + return self._headers_sent + + @property + def trailers_sent(self) -> bool: + return self._trailers_sent + + def set_compression(self, compression: str) -> None: + if compression == Identity.get_message_encoding(): + return + try: + decompressor = extensionLoader.get_extension(Decompressor, compression)() + self._tri_encoder.compressor = decompressor + except ExtensionError: + _LOGGER.warning(f"Unsupported compression: {compression}") + self.cancel_by_local( + TriRpcStatus(GRpcCode.INTERNAL, 
description="Unsupported compression") + ) + + def send_headers(self, headers: Http2Headers) -> None: + if not self.headers_sent: + self._stream.send_headers(headers) + self._headers_sent = True + + def send_message(self, data: bytes, compress_flag: bool) -> None: + # encode the message + encoded_data = self._tri_encoder.encode(data, compress_flag) + self._stream.send_data(encoded_data, end_stream=False) + + def complete(self, status: TriRpcStatus, attachments: Dict[str, Any]) -> None: + trailers = Http2Headers() + if not self.headers_sent: + trailers.status = HttpStatus.OK.value + trailers.add( + TripleHeaderName.CONTENT_TYPE.value, + TripleHeaderValue.APPLICATION_GRPC_PROTO.value, + ) + + # add attachments + [trailers.add(k, v) for k, v in attachments.items()] + + # add status + trailers.add(TripleHeaderName.GRPC_STATUS.value, status.code.value) + if status.code is not GRpcCode.OK: + trailers.add( + TripleHeaderName.GRPC_MESSAGE.value, + TriRpcStatus.limit_desc(status.description), + ) + + # send trailers + self._headers_sent = True + self._trailers_sent = True + self._stream.send_headers(trailers, end_stream=True) + + def cancel_by_local(self, status: TriRpcStatus) -> None: + if _LOGGER.is_enabled_for(Level.DEBUG): + _LOGGER.debug(f"Cancel stream:{self._stream} by local: {status}") + + if not self._rst: + self._rst = True + self._stream.cancel_by_local(Http2ErrorCode.CANCEL) + + +class ServerTransportListener(Http2Stream.Listener): + """ + ServerTransportListener is a callback interface that receives events on the stream. + """ + + def __init__(self, service_handles: Dict[str, RpcServiceHandler]): + super().__init__() + self._listener: Optional[ServerStream.Listener] = None + self._decoder: Optional[TriDecoder] = None + self._service_handles = service_handles + + def on_headers(self, headers: Http2Headers, end_stream: bool) -> None: + # check http method + if headers.method != HttpMethod.POST.value: + self._response_plain_text_error( + HttpStatus.METHOD_NOT_ALLOWED.value, + TriRpcStatus( + GRpcCode.INTERNAL, + description=f"Method {headers.method} is not supported", + ), + ) + return + + # check content type + content_type = headers.get(TripleHeaderName.CONTENT_TYPE.value, "") + if not content_type.startswith(TripleHeaderValue.APPLICATION_GRPC.value): + self._response_plain_text_error( + HttpStatus.UNSUPPORTED_MEDIA_TYPE.value, + TriRpcStatus( + GRpcCode.UNIMPLEMENTED, + description=( + f"Content-Type {content_type} is not supported" + if content_type + else "Content-Type is missing from the request" + ), + ), + ) + return + + # check path + path = headers.path + if not path: + self._response_plain_text_error( + HttpStatus.NOT_FOUND.value, + TriRpcStatus( + GRpcCode.UNIMPLEMENTED, + description="Expected path but is missing", + ), + ) + return + elif not path.startswith("/"): + self._response_plain_text_error( + HttpStatus.NOT_FOUND.value, + TriRpcStatus( + GRpcCode.UNIMPLEMENTED, + description=f"Expected path to start with /: {path}", + ), + ) + return + + # split the path + parts = path.split("/") + if len(parts) != 3: + self._response_error( + TriRpcStatus( + GRpcCode.UNIMPLEMENTED, description=f"Bad path format: {path}" + ) + ) + return + + service_name, method_name = parts[1], parts[2] + + # get method handler + handler = self._get_handler(service_name, method_name) + if not handler: + self._response_error( + TriRpcStatus( + GRpcCode.UNIMPLEMENTED, + description=f"Service {service_name} is not found", + ) + ) + return + + if end_stream: + # Invalid request, ignore it. 
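+            # END_STREAM on the request headers means no DATA frames (and thus
+            # no gRPC message) can follow, so there is nothing to dispatch.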
+ return + + decompressor: Decompressor = Identity() + message_encoding = headers.get(TripleHeaderName.GRPC_ENCODING.value) + if message_encoding and message_encoding != decompressor.get_message_encoding(): + # update decompressor + try: + decompressor = extensionLoader.get_extension( + Decompressor, message_encoding + )() + except ExtensionError: + self._response_error( + TriRpcStatus( + GRpcCode.UNIMPLEMENTED, + description=f"Grpc-encoding '{message_encoding}' is not supported", + ) + ) + return + + # create a server call + self._listener = TripleServerCall(TripleServerStream(self._stream), handler) + + # create a decoder + self._decoder = TriDecoder( + ServerTransportListener.ServerDecoderListener(self._listener), decompressor + ) + + # deliver the headers to the listener + self._listener.on_headers(headers.to_dict()) + + def _get_handler( + self, service_name: str, method_name: str + ) -> Optional[RpcMethodHandler]: + """ + Get the method handler. + :param service_name: The service name + :type service_name: str + :param method_name: The method name + :type method_name: str + :return: The method handler + :rtype: Optional[RpcMethodHandler] + """ + if self._service_handles: + service_handler = self._service_handles.get(service_name) + if service_handler: + return service_handler.method_handlers.get(method_name) + return None + + def on_data(self, data: bytes, end_stream: bool) -> None: + if self._decoder: + self._decoder.decode(data) + if end_stream: + self._decoder.close() + + def cancel_by_remote(self, error_code: Http2ErrorCode) -> None: + if self._listener: + self._listener.on_cancel_by_remote( + TriRpcStatus( + GRpcCode.CANCELLED, + description=f"Canceled by client ,errorCode= {error_code.value}", + ) + ) + + def _response_plain_text_error(self, code: int, status: TriRpcStatus) -> None: + """ + Error before create server stream, http plain text will be returned. + :param code: The error code + :type code: int + :param status: The status + :type status: TriRpcStatus + """ + # create headers + headers = Http2Headers() + headers.status = code + headers.add(TripleHeaderName.GRPC_STATUS.value, status.code.value) + headers.add(TripleHeaderName.GRPC_MESSAGE.value, status.description) + headers.add( + TripleHeaderName.CONTENT_TYPE.value, TripleHeaderValue.TEXT_PLAIN_UTF8.value + ) + + # send headers + self._stream.send_headers(headers, end_stream=True) + + def _response_error(self, status: TriRpcStatus) -> None: + """ + Error after create server stream, grpc error will be returned. + :param status: The status + :type status: TriRpcStatus + """ + # create trailers + trailers = Http2Headers() + trailers.status = HttpStatus.OK.value + trailers.add(TripleHeaderName.GRPC_STATUS.value, status.code.value) + trailers.add(TripleHeaderName.GRPC_MESSAGE.value, status.description) + trailers.add( + TripleHeaderName.CONTENT_TYPE.value, + TripleHeaderValue.APPLICATION_GRPC_PROTO.value, + ) + + # send trailers + self._stream.send_headers(trailers, end_stream=True) + + class ServerDecoderListener(TriDecoder.Listener): + """ + ServerDecoderListener is a callback interface that receives events on the decoder. 
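+
+        It forwards each decoded message to the server stream listener and
+        signals half-close via on_complete() when the decoder is closed.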
+ """ + + def __init__(self, listener: ServerStream.Listener): + self._listener = listener + + def on_message(self, message: bytes) -> None: + self._listener.on_message(message) + + def close(self): + self._listener.on_complete() diff --git a/dubbo/protocol/triple/tri_invoker.py b/dubbo/protocol/triple/tri_invoker.py deleted file mode 100644 index c23bf7f..0000000 --- a/dubbo/protocol/triple/tri_invoker.py +++ /dev/null @@ -1,140 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -from dubbo.compressor.compression import Compression -from dubbo.constants import common_constants -from dubbo.extension import extensionLoader -from dubbo.logger.logger_factory import loggerFactory -from dubbo.protocol.invocation import Invocation, RpcInvocation -from dubbo.protocol.invoker import Invoker -from dubbo.protocol.result import Result -from dubbo.protocol.triple.client.calls import TriClientCall -from dubbo.protocol.triple.client.stream_listener import TriClientStreamListener -from dubbo.protocol.triple.tri_constants import TripleHeaderName, TripleHeaderValue -from dubbo.protocol.triple.tri_results import TriResult -from dubbo.remoting.aio.http2.headers import Http2Headers, MethodType -from dubbo.remoting.aio.http2.stream_handler import StreamClientMultiplexHandler -from dubbo.remoting.transporter import Client -from dubbo.url import URL - -logger = loggerFactory.get_logger(__name__) - - -class TriInvoker(Invoker): - """ - Triple invoker. 
- """ - - def __init__( - self, url: URL, client: Client, stream_multiplexer: StreamClientMultiplexHandler - ): - self._url = url - self._client = client - self._stream_multiplexer = stream_multiplexer - - self._compression: Optional[Compression] = None - compression_type = url.get_parameter(common_constants.COMPRESSION) - if compression_type: - self._compression = extensionLoader.get_extension( - Compression, compression_type - ) - - self._destroyed = False - - def invoke(self, invocation: RpcInvocation) -> Result: - call_type = invocation.get_attribute(common_constants.CALL_KEY) - result = TriResult(call_type) - - if not self._client.is_connected(): - # Reconnect the client - self._client.reconnect() - - # Create a new TriClientCall - tri_client_call = TriClientCall( - result, - serialization=invocation.get_attribute(common_constants.SERIALIZATION), - compression=self._compression, - ) - - # Create a new stream - stream = self._stream_multiplexer.create( - TriClientStreamListener(tri_client_call.listener, self._compression) - ) - tri_client_call.bind_stream(stream) - - if call_type in ( - common_constants.CALL_UNARY, - common_constants.CALL_SERVER_STREAM, - ): - self._invoke_unary(tri_client_call, invocation) - elif call_type in ( - common_constants.CALL_CLIENT_STREAM, - common_constants.CALL_BIDI_STREAM, - ): - self._invoke_stream(tri_client_call, invocation) - - return result - - def _invoke_unary(self, call: TriClientCall, invocation: Invocation) -> None: - call.send_headers(self._create_headers(invocation)) - call.send_message(invocation.get_argument(), last=True) - - def _invoke_stream(self, call: TriClientCall, invocation: Invocation) -> None: - call.send_headers(self._create_headers(invocation)) - next_message = None - for message in invocation.get_argument(): - if next_message is not None: - call.send_message(next_message, last=False) - next_message = message - call.send_message(next_message, last=True) - - def _create_headers(self, invocation: Invocation) -> Http2Headers: - - headers = Http2Headers() - headers.scheme = TripleHeaderValue.HTTP.value - headers.method = MethodType.POST - headers.authority = self._url.location - # set path - path = "" - if invocation.get_service_name(): - path += f"/{invocation.get_service_name()}" - path += f"/{invocation.get_method_name()}" - headers.path = path - - # set content type - headers.content_type = TripleHeaderValue.APPLICATION_GRPC_PROTO.value - - # set te - headers.add(TripleHeaderName.TE.value, TripleHeaderValue.TRAILERS.value) - - return headers - - def get_url(self) -> URL: - return self._url - - def is_available(self) -> bool: - return self._client.is_connected() - - @property - def destroyed(self) -> bool: - return self._destroyed - - def destroy(self) -> None: - self._client.close() - self._client = None - self._stream_multiplexer = None - self._url = None diff --git a/dubbo/protocol/triple/tri_protocol.py b/dubbo/protocol/triple/tri_protocol.py deleted file mode 100644 index 4c28625..0000000 --- a/dubbo/protocol/triple/tri_protocol.py +++ /dev/null @@ -1,61 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from concurrent.futures import ThreadPoolExecutor - -from dubbo.constants import common_constants -from dubbo.extension import extensionLoader -from dubbo.logger.logger_factory import loggerFactory -from dubbo.protocol.invoker import Invoker -from dubbo.protocol.protocol import Protocol -from dubbo.protocol.triple.tri_invoker import TriInvoker -from dubbo.remoting.aio.http2.protocol import Http2Protocol -from dubbo.remoting.aio.http2.stream_handler import StreamClientMultiplexHandler -from dubbo.remoting.transporter import Transporter -from dubbo.url import URL - -logger = loggerFactory.get_logger(__name__) - - -class TripleProtocol(Protocol): - - def __init__(self, url: URL): - self._url = url - self._transporter: Transporter = extensionLoader.get_extension( - Transporter, - self._url.get_parameter(common_constants.TRANSPORTER_KEY) or "aio", - )() - self._invokers = [] - - def refer(self, url: URL) -> Invoker: - """ - Refer a remote service. - Args: - url (URL): The URL of the remote service. - """ - executor = ThreadPoolExecutor(thread_name_prefix="dubbo-tri-") - # Create a stream handler - stream_multiplexer = StreamClientMultiplexHandler(executor) - # set stream handler and protocol - url.attributes[common_constants.TRANSPORTER_STREAM_HANDLER_KEY] = ( - stream_multiplexer - ) - url.attributes[common_constants.TRANSPORTER_PROTOCOL_KEY] = Http2Protocol - - # Create a client - client = self._transporter.connect(url) - invoker = TriInvoker(url, client, stream_multiplexer) - self._invokers.append(invoker) - return invoker diff --git a/dubbo/protocol/triple/client/__init__.py b/dubbo/proxy/__init__.py similarity index 87% rename from dubbo/protocol/triple/client/__init__.py rename to dubbo/proxy/__init__.py index bcba37a..4c4ddd8 100644 --- a/dubbo/protocol/triple/client/__init__.py +++ b/dubbo/proxy/__init__.py @@ -13,3 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from ._interfaces import RpcCallable, RpcCallableFactory + +__all__ = ["RpcCallable", "RpcCallableFactory"] diff --git a/dubbo/proxy/_interfaces.py b/dubbo/proxy/_interfaces.py new file mode 100644 index 0000000..d6c9c98 --- /dev/null +++ b/dubbo/proxy/_interfaces.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
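+
+"""Abstractions for the RPC proxy layer.
+
+An RpcCallable lets a referred service be invoked like a plain Python callable;
+an RpcCallableFactory builds callables (consumer side) and invokers (provider
+side) from a URL. A rough consumer-side sketch, with the factory and invoker
+assumed to be created elsewhere:
+
+    proxy = callable_factory.get_callable(invoker, url)
+    reply = proxy(request)
+"""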
+ +import abc + +from dubbo.common import URL +from dubbo.protocol import Invoker +from dubbo.proxy.handlers import RpcServiceHandler + +__all__ = [ + "RpcCallable", + "RpcCallableFactory", +] + + +class RpcCallable(abc.ABC): + + @abc.abstractmethod + def __call__(self, *args, **kwargs): + """ + call the rpc service + """ + raise NotImplementedError() + + +class RpcCallableFactory(abc.ABC): + + @abc.abstractmethod + def get_callable(self, invoker: Invoker, url: URL) -> RpcCallable: + """ + get the rpc proxy + :param invoker: the invoker. + :type invoker: Invoker + :param url: the url. + :type url: URL + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_invoker(self, service_handler: RpcServiceHandler, url: URL) -> Invoker: + """ + get the rpc invoker + :param service_handler: the service handler. + :type service_handler: RpcServiceHandler + :param url: the url. + :type url: URL + """ + raise NotImplementedError() diff --git a/dubbo/callable.py b/dubbo/proxy/callables.py similarity index 57% rename from dubbo/callable.py rename to dubbo/proxy/callables.py index 0481818..5f17098 100644 --- a/dubbo/callable.py +++ b/dubbo/proxy/callables.py @@ -13,24 +13,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Any -from dubbo.constants import common_constants +from dubbo.common import constants as common_constants +from dubbo.common.url import URL +from dubbo.protocol import Invoker from dubbo.protocol.invocation import RpcInvocation -from dubbo.protocol.invoker import Invoker -from dubbo.url import URL +from dubbo.proxy import RpcCallable, RpcCallableFactory + +__all__ = ["MultipleRpcCallable"] + +from dubbo.proxy.handlers import RpcServiceHandler -class AbstractRpcCallable: +class MultipleRpcCallable(RpcCallable): + """ + The RpcCallable class. + """ def __init__(self, invoker: Invoker, url: URL): self._invoker = invoker self._url = url self._service_name = self._url.path - self._method_name = self._url.get_parameter(common_constants.METHOD_KEY) - self._call_type = self._url.get_parameter(common_constants.CALL_KEY) + self._method_name = self._url.parameters[common_constants.METHOD_KEY] + self._call_type = self._url.parameters[common_constants.CALL_KEY] - self._serialization = self._url.attributes[common_constants.SERIALIZATION] + self._serializer = self._url.attributes[common_constants.SERIALIZER_KEY] + self._deserializer = self._url.attributes[common_constants.DESERIALIZER_KEY] def _create_invocation(self, argument: Any) -> RpcInvocation: return RpcInvocation( @@ -39,16 +49,26 @@ def _create_invocation(self, argument: Any) -> RpcInvocation: argument, attributes={ common_constants.CALL_KEY: self._call_type, - common_constants.SERIALIZATION: self._serialization, + common_constants.SERIALIZER_KEY: self._serializer, + common_constants.DESERIALIZER_KEY: self._deserializer, }, ) - -class RpcCallable(AbstractRpcCallable): - def __call__(self, argument: Any) -> Any: # Create a new RpcInvocation invocation = self._create_invocation(argument) # Do invoke. result = self._invoker.invoke(invocation) return result.value() + + +class DefaultRpcCallableFactory(RpcCallableFactory): + """ + The RpcCallableFactory class. 
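+
+    Builds MultipleRpcCallable proxies for the consumer side; the provider-side
+    get_invoker is not implemented in this patch.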
+ """ + + def get_callable(self, invoker: Invoker, url: URL) -> RpcCallable: + return MultipleRpcCallable(invoker, url) + + def get_invoker(self, service_handler: RpcServiceHandler, url: URL) -> Invoker: + pass diff --git a/dubbo/proxy/handlers.py b/dubbo/proxy/handlers.py new file mode 100644 index 0000000..26fbce0 --- /dev/null +++ b/dubbo/proxy/handlers.py @@ -0,0 +1,136 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Optional + +from dubbo.common import constants as common_constants +from dubbo.common.types import DeserializingFunction, SerializingFunction + +__all__ = ["RpcMethodHandler", "RpcServiceHandler"] + + +class RpcMethodHandler: + """ + Rpc method handler + """ + + def __init__( + self, + call_type: str, + behavior: Callable, + request_serializer: Optional[SerializingFunction] = None, + response_serializer: Optional[DeserializingFunction] = None, + ): + """ + Initialize the RpcMethodHandler + :param call_type: the call type. + :type call_type: str + :param behavior: the behavior of the method. + :type behavior: Callable + :param request_serializer: the request serializer. + :type request_serializer: Optional[SerializingFunction] + :param response_serializer: the response serializer. 
+ :type response_serializer: Optional[DeserializingFunction] + """ + self.call_type = call_type + self.behavior = behavior + self.request_serializer = request_serializer + self.response_serializer = response_serializer + + @classmethod + def unary( + cls, + behavior: Callable, + request_serializer: Optional[SerializingFunction] = None, + response_serializer: Optional[DeserializingFunction] = None, + ): + """ + Create a unary method handler + """ + return cls( + common_constants.UNARY_CALL_VALUE, + behavior, + request_serializer, + response_serializer, + ) + + @classmethod + def client_stream( + cls, + behavior: Callable, + request_serializer: SerializingFunction, + response_serializer: DeserializingFunction, + ): + """ + Create a client stream method handler + """ + return cls( + common_constants.CLIENT_STREAM_CALL_VALUE, + behavior, + request_serializer, + response_serializer, + ) + + @classmethod + def server_stream( + cls, + behavior: Callable, + request_serializer: SerializingFunction, + response_serializer: DeserializingFunction, + ): + """ + Create a server stream method handler + """ + return cls( + common_constants.SERVER_STREAM_CALL_VALUE, + behavior, + request_serializer, + response_serializer, + ) + + @classmethod + def bi_stream( + cls, + behavior: Callable, + request_serializer: SerializingFunction, + response_serializer: DeserializingFunction, + ): + """ + Create a bidi stream method handler + """ + return cls( + common_constants.BI_STREAM_CALL_VALUE, + behavior, + request_serializer, + response_serializer, + ) + + +class RpcServiceHandler: + """ + Rpc service handler + """ + + def __init__(self, service_name: str, method_handlers: Dict[str, RpcMethodHandler]): + """ + Initialize the RpcServiceHandler + :param service_name: the name of the service. + :type service_name: str + :param method_handlers: the method handlers. + :type method_handlers: Dict[str, RpcMethodHandler] + """ + self.service_name = service_name + self.method_handlers = method_handlers diff --git a/dubbo/compressor/__init__.py b/dubbo/registry/__init__.py similarity index 93% rename from dubbo/compressor/__init__.py rename to dubbo/registry/__init__.py index bcba37a..52dfd01 100644 --- a/dubbo/compressor/__init__.py +++ b/dubbo/registry/__init__.py @@ -13,3 +13,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from ._interfaces import Registry, RegistryFactory diff --git a/dubbo/registry/_interfaces.py b/dubbo/registry/_interfaces.py new file mode 100644 index 0000000..3902208 --- /dev/null +++ b/dubbo/registry/_interfaces.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
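+
+"""Registry abstractions.
+
+A Registry registers and unregisters service URLs and lets consumers subscribe
+to and look up providers; a RegistryFactory creates a registry from its URL.
+A rough sketch, with the concrete factory assumed:
+
+    registry = registry_factory.get_registry(registry_url)
+    registry.register(service_url)
+    registry.subscribe(service_url, listener)
+"""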
+ +import abc + +from dubbo.common import URL, Node + +__all__ = ["Registry", "RegistryFactory"] + + +class Registry(Node, abc.ABC): + + @abc.abstractmethod + def register(self, url: URL) -> None: + """ + Register a service to registry. + + :param URL url: The service URL. + :return: None + """ + raise NotImplementedError() + + @abc.abstractmethod + def unregister(self, url: URL) -> None: + """ + Unregister a service from registry. + + :param URL url: The service URL. + """ + raise NotImplementedError() + + @abc.abstractmethod + def subscribe(self, url: URL, listener): + """ + Subscribe a service from registry. + :param URL url: The service URL. + :param listener: The listener to notify when service changed. + """ + raise NotImplementedError() + + @abc.abstractmethod + def unsubscribe(self, url: URL, listener): + """ + Unsubscribe a service from registry. + :param URL url: The service URL. + :param listener: The listener to notify when service changed. + """ + raise NotImplementedError() + + @abc.abstractmethod + def lookup(self, url: URL): + """ + Lookup a service from registry. + :param URL url: The service URL. + """ + raise NotImplementedError() + + +class RegistryFactory(abc.ABC): + + @abc.abstractmethod + def get_registry(self, url: URL) -> Registry: + """ + Get a registry instance. + + :param URL url: The registry URL. + :return: The registry instance. + """ + raise NotImplementedError() diff --git a/tests/test_dubbo.py b/dubbo/registry/zookeeper/__init__.py similarity index 85% rename from tests/test_dubbo.py rename to dubbo/registry/zookeeper/__init__.py index a9cdebd..a1af7e7 100644 --- a/tests/test_dubbo.py +++ b/dubbo/registry/zookeeper/__init__.py @@ -13,12 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import unittest - -class TestDubbo(unittest.TestCase): - - def test_dubbo(self): - from dubbo import Dubbo - - Dubbo() +from ._interfaces import ( + ChildrenListener, + DataListener, + StateListener, + ZookeeperClient, + ZookeeperTransport, +) diff --git a/dubbo/registry/zookeeper/_interfaces.py b/dubbo/registry/zookeeper/_interfaces.py new file mode 100644 index 0000000..f2292e6 --- /dev/null +++ b/dubbo/registry/zookeeper/_interfaces.py @@ -0,0 +1,251 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import enum + +from dubbo.common import URL + +__all__ = [ + "StateListener", + "DataListener", + "ChildrenListener", + "ZookeeperClient", + "ZookeeperTransport", +] + + +class StateListener(abc.ABC): + class State(enum.Enum): + """ + Zookeeper connection state. 
+ """ + + SUSPENDED = "SUSPENDED" + CONNECTED = "CONNECTED" + LOST = "LOST" + + @abc.abstractmethod + def state_changed(self, state: "StateListener.State") -> None: + """ + Notify when connection state changed. + + :param StateListener.State state: The new connection state. + """ + raise NotImplementedError() + + +class DataListener(abc.ABC): + class EventType(enum.Enum): + """ + Zookeeper data event type. + """ + + CREATED = "CREATED" + DELETED = "DELETED" + CHANGED = "CHANGED" + CHILD = "CHILD" + NONE = "NONE" + + @abc.abstractmethod + def data_changed( + self, path: str, data: bytes, event_type: "DataListener.EventType" + ) -> None: + """ + Notify when data changed. + + :param str path: The node path. + :param bytes data: The new data. + :param DataListener.EventType event_type: The event type. + """ + raise NotImplementedError() + + +class ChildrenListener(abc.ABC): + @abc.abstractmethod + def children_changed(self, path: str, children: list) -> None: + """ + Notify when children changed. + + :param str path: The node path. + :param list children: The new children. + """ + raise NotImplementedError() + + +class ZookeeperClient(abc.ABC): + """ + Zookeeper Client interface. + """ + + __slots__ = ["_url"] + + def __init__(self, url: URL): + """ + Initialize the zookeeper client. + + :param URL url: The zookeeper URL. + """ + self._url = url + + @abc.abstractmethod + def start(self) -> None: + """ + Start the zookeeper client. + """ + raise NotImplementedError() + + @abc.abstractmethod + def stop(self) -> None: + """ + Stop the zookeeper client. + """ + raise NotImplementedError() + + @abc.abstractmethod + def is_connected(self) -> bool: + """ + Check if the client is connected to zookeeper. + + :return: True if connected, False otherwise. + """ + raise NotImplementedError() + + @abc.abstractmethod + def create(self, path: str, ephemeral=False) -> None: + """ + Create a node in zookeeper. + + :param str path: The node path. + :param bool ephemeral: Whether the node is ephemeral. False: persistent, True: ephemeral. + """ + raise NotImplementedError() + + @abc.abstractmethod + def create_or_update(self, path: str, data: bytes, ephemeral=False) -> None: + """ + Create or update a node in zookeeper. + + :param str path: The node path. + :param bytes data: The node data. + :param bool ephemeral: Whether the node is ephemeral. False: persistent, True: ephemeral. + """ + raise NotImplementedError() + + @abc.abstractmethod + def check_exist(self, path: str) -> bool: + """ + Check if a node exists in zookeeper. + + :param str path: The node path. + :return: True if the node exists, False otherwise. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_data(self, path: str) -> bytes: + """ + Get data of a node in zookeeper. + + :param str path: The node path. + :return: The node data. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_children(self, path: str) -> list: + """ + Get children of a node in zookeeper. + + :param str path: The node path. + :return: The children of the node. + """ + raise NotImplementedError() + + @abc.abstractmethod + def delete(self, path: str) -> None: + """ + Delete a node in zookeeper. + + :param str path: The node path. + """ + raise NotImplementedError() + + @abc.abstractmethod + def add_state_listener(self, listener: StateListener) -> None: + """ + Add a state listener to zookeeper. + + :param StateListener listener: The listener to notify when connection state changed. 
+ """ + raise NotImplementedError() + + @abc.abstractmethod + def remove_state_listener(self, listener: StateListener) -> None: + """ + Remove a state listener from zookeeper. + + :param StateListener listener: The listener to remove. + """ + raise NotImplementedError() + + @abc.abstractmethod + def add_data_listener(self, path: str, listener: DataListener) -> None: + """ + Add a data listener to a node in zookeeper. + + :param str path: The node path. + :param DataListener listener: The listener to notify when data changed. + """ + raise NotImplementedError() + + @abc.abstractmethod + def remove_data_listener(self, listener: DataListener) -> None: + """ + Remove a data listener from a node in zookeeper. + + :param DataListener listener: The listener to remove. + """ + raise NotImplementedError() + + @abc.abstractmethod + def add_children_listener(self, path: str, listener: ChildrenListener) -> None: + """ + Add a children listener to a node in zookeeper. + + :param str path: The node path. + :param ChildrenListener listener: The listener to notify when children changed. + """ + raise NotImplementedError() + + @abc.abstractmethod + def remove_children_listener(self, listener: ChildrenListener) -> None: + """ + Remove a children listener from a node in zookeeper. + + :param ChildrenListener listener: The listener to remove. + """ + raise NotImplementedError() + + +class ZookeeperTransport(abc.ABC): + + @abc.abstractmethod + def connect(self, url: URL) -> ZookeeperClient: + """ + Connect to a zookeeper. + """ + raise NotImplementedError() diff --git a/dubbo/registry/zookeeper/kazoo_transport.py b/dubbo/registry/zookeeper/kazoo_transport.py new file mode 100644 index 0000000..8bf678e --- /dev/null +++ b/dubbo/registry/zookeeper/kazoo_transport.py @@ -0,0 +1,427 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import threading +from typing import Dict, List, Union + +from kazoo.client import KazooClient +from kazoo.protocol.states import EventType, KazooState, WatchedEvent, ZnodeStat + +from dubbo.common import URL +from dubbo.logger import loggerFactory + +from ._interfaces import ( + ChildrenListener, + DataListener, + StateListener, + ZookeeperClient, + ZookeeperTransport, +) + +__all__ = ["KazooZookeeperClient", "KazooZookeeperTransport"] + +_LOGGER = loggerFactory.get_logger(__name__) + +LISTENER_TYPE = Union[StateListener, DataListener, ChildrenListener] + + +class AbstractListenerAdapter(abc.ABC): + """ + Abstract listener adapter. + + This abstract class defines a template for listener adapters, providing thread-safe methods to + reset and remove listeners. Concrete implementations should provide specific behavior for these methods. 
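+
+    Adapters translate Kazoo callbacks (connection state changes, watched data
+    events) into the listener interfaces defined in _interfaces.py.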
+ """ + + __slots__ = ["_lock", "_listener"] + + def __init__(self, listener: LISTENER_TYPE): + """ + Initialize the adapter with a reentrant lock to ensure thread safety. + :param listener: The listener. + :type listener: StateListener or DataListener or ChildrenListener + """ + self._lock = threading.Lock() + self._listener = listener + + def get_listener(self) -> LISTENER_TYPE: + """ + Get the listener. + :return: The listener. + :rtype: StateListener or DataListener or ChildrenListener + """ + return self._listener + + def reset(self, listener: LISTENER_TYPE) -> None: + """ + Reset with a new listener. + + :param listener: The new listener to set. + :type listener: StateListener or DataListener or ChildrenListener + """ + with self._lock: + self._listener = listener + + def remove(self) -> None: + """ + Remove the current listener. + + """ + with self._lock: + self._listener = None + + +class AbstractListenerAdapterFactory(abc.ABC): + """ + Abstract factory for creating and managing listener adapters. + + This abstract factory class provides methods to create and remove listener adapters in a + thread-safe manner. It maintains dictionaries to track active and inactive adapters. + """ + + __slots__ = [ + "_client", + "_lock", + "_listener_to_path", + "_active_adapters", + "_inactive_adapters", + ] + + def __init__(self, client: KazooClient): + """ + Initialize the factory with a KazooClient and set up the necessary locks and dictionaries. + + :param client: An instance of KazooClient to manage Zookeeper connections. + :type client: KazooClient + """ + self._client = client + self._lock = threading.Lock() + + self._listener_to_path = {} + self._active_adapters: Dict[str, AbstractListenerAdapter] = {} + self._inactive_adapters: Dict[str, AbstractListenerAdapter] = {} + + def create(self, path: str, listener) -> None: + """ + Create a new adapter or re-enable an inactive one. + + This method checks if the listener already has an active or inactive adapter. If the adapter is + inactive, it re-enables it. Otherwise, it creates a new adapter using the abstract `do_create` method. + + :param path: The Znode path to watch. + :type path: str + :param listener: The listener for which to create or re-enable an adapter. + :type listener: Any + """ + with self._lock: + adapter = self._active_adapters.pop(path, None) + if adapter is not None: + if adapter.get_listener() == listener: + return + else: + # replace the listener + adapter.reset(listener) + elif path in self._inactive_adapters: + # Re-enabling inactive adapter + adapter = self._inactive_adapters.pop(path) + adapter.reset(listener) + else: + # Creating a new adapter + adapter = self.do_create(path, listener) + + self._listener_to_path[listener] = path + self._active_adapters[path] = adapter + + def remove(self, listener) -> None: + """ + Remove the current listener and move its adapter to the inactive dictionary. + + This method removes the adapter associated with the listener from the active dictionary, + calls its `remove` method, and then stores it in the inactive dictionary. + + :param listener: The listener whose adapter is to be removed. + :type listener: Any + """ + with self._lock: + path = self._listener_to_path.pop(listener, None) + if path is None: + return + adapter = self._active_adapters.pop(path) + if adapter is not None: + adapter.remove() + self._inactive_adapters[path] = adapter + + @abc.abstractmethod + def do_create(self, path: str, listener) -> AbstractListenerAdapter: + """ + Define the creation of a new adapter. 
+ + This abstract method must be implemented by subclasses to handle the actual creation logic + for a new adapter. + + :param path: The Znode path to watch. + :type path: str + :param listener: The listener for which to create a new adapter. + :type listener: Any + :return: A new instance of an AbstractListenerAdapter. + :rtype: AbstractListenerAdapter + :raises NotImplementedError: If the method is not implemented by a subclass. + """ + raise NotImplementedError() + + +class StateListenerAdapter(AbstractListenerAdapter): + """ + State listener adapter. + + This adapter inherits from :class:`AbstractListenerAdapter`, but it does not need to use the `reset` + and `remove` methods. The :class:`KazooClient` provides the `add_listener` and `remove_listener` + methods, which can effectively replace these methods. + + Note: + The `add_listener` and `remove_listener` methods of :class:`KazooClient` offer a more efficient + and straightforward way to manage state listeners, making the `reset` and `remove` methods redundant. + """ + + def __init__(self, listener: StateListener): + super().__init__(listener) + + def __call__(self, state: KazooState): + """ + Handle state changes and notify the listener. + + This method is called with the current state of the KazooClient. + + :param state: The current state of the KazooClient. + :type state: KazooState + """ + if state == KazooState.CONNECTED: + state = StateListener.State.CONNECTED + elif state == KazooState.LOST: + state = StateListener.State.LOST + elif state == KazooState.SUSPENDED: + state = StateListener.State.SUSPENDED + + self._listener.state_changed(state) + + +class DataListenerAdapter(AbstractListenerAdapter): + """ + Data listener adapter. + + This adapter handles data change events from a specified Znode path and notifies a `DataListener`. + It should be used in conjunction with `AbstractListenerAdapterFactory` to manage listener creation + and removal. + """ + + __slots__ = ["_path"] + + def __init__(self, path: str, listener: DataListener): + """ + Initialize the KazooDataListenerAdapter with a given path and listener. + + :param path: The Znode path to watch. + :type path: str + :param listener: The data listener to notify on data changes. + :type listener: DataListener + """ + super().__init__(listener) + self._path = path + + def __call__(self, data: bytes, stat: ZnodeStat, event: WatchedEvent): + """ + Handle data changes and notify the listener. + + This method is called with the current data, stat, and event of the watched Znode. + + :param data: The current data of the Znode. + :type data: bytes + :param stat: The status of the Znode. + :type stat: ZnodeStat + :param event: The event that triggered the callback. + :type event: WatchedEvent + """ + with self._lock: + if event is None or self._listener is None: + # This callback is called once immediately after being added, and at this point, event is None. + # Since a non-existent node also returns None, to avoid handling unknown None exceptions, + # we directly filter out all cases of None. 
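+                # Note: kazoo's DataWatch inspects the callback signature and passes the third
+                # `event` argument only because this __call__ accepts three parameters; a
+                # two-argument callable would receive (data, stat) only.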
+ return + + event_type = None + if event.type == EventType.NONE: + event_type = DataListener.EventType.NONE + elif event.type == EventType.CREATED: + event_type = DataListener.EventType.CREATED + elif event.type == EventType.DELETED: + event_type = DataListener.EventType.DELETED + elif event.type == EventType.CHANGED: + event_type = DataListener.EventType.CHANGED + elif event.type == EventType.CHILD: + event_type = DataListener.EventType.CHILD + + self._listener.data_changed(self._path, data, event_type) + + +class ChildrenListenerAdapter(AbstractListenerAdapter): + """ + Children listener adapter. + + This adapter handles children change events from a specified Znode path and notifies a `ChildrenListener`. + It should be used in conjunction with `AbstractListenerAdapterFactory` to manage listener creation and removal. + """ + + def __init__(self, path: str, listener: ChildrenListener): + """ + Initialize the ChildrenListenerAdapter with a given path and listener. + + :param path: The Znode path to watch. + :type path: str + :param listener: The children listener to notify on children changes. + :type listener: ChildrenListener + """ + super().__init__(listener) + self._path = path + + def __call__(self, children: List[str]): + """ + Handle children changes and notify the listener. + + This method is called with the current list of children of the watched Znode. + + :param children: The current list of children of the Znode. + :type children: List[str] + """ + with self._lock: + if self._listener is not None: + self._listener.children_changed(self._path, children) + + +class DataListenerAdapterFactory(AbstractListenerAdapterFactory): + + def do_create(self, path: str, listener: DataListener) -> AbstractListenerAdapter: + data_adapter = DataListenerAdapter(path, listener) + self._client.DataWatch(path, data_adapter) + return data_adapter + + +class ChildrenListenerAdapterFactory(AbstractListenerAdapterFactory): + + def do_create( + self, path: str, listener: ChildrenListener + ) -> AbstractListenerAdapter: + children_adapter = ChildrenListenerAdapter(path, listener) + self._client.ChildrenWatch(path, children_adapter) + return children_adapter + + +class KazooZookeeperClient(ZookeeperClient): + """ + Kazoo Zookeeper client. 
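+
+    A minimal usage sketch (illustrative only; assumes url.location points to a reachable
+    Zookeeper ensemble, e.g. "127.0.0.1:2181", and that the node paths below exist or may
+    be created):
+
+        client = KazooZookeeperClient(url)
+        client.start()
+        client.create_or_update("/dubbo/demo", b"payload")
+        data = client.get_data("/dubbo/demo")
+        client.stop()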
+ """ + + def __init__(self, url: URL): + super().__init__(url) + self._client: KazooClient = KazooClient(hosts=url.location) + # TODO: Add more attributes from url + + # state listener dict + self._state_lock = threading.Lock() + self._state_listeners: Dict[StateListener, StateListenerAdapter] = {} + + self._data_adapter_factory = DataListenerAdapterFactory(self._client) + + self._children_adapter_factory = ChildrenListenerAdapterFactory(self._client) + + def start(self) -> None: + # start the client + self._client.start() + + def stop(self) -> None: + # stop the client + self._client.stop() + + def is_connected(self) -> bool: + return self._client.connected + + def create(self, path: str, ephemeral=False) -> None: + self._client.create(path, ephemeral=ephemeral) + + def create_or_update(self, path: str, data: bytes, ephemeral=False) -> None: + if self.check_exist(path): + self._client.set(path, data) + else: + self._client.create(path, data, ephemeral=ephemeral) + + def check_exist(self, path: str) -> bool: + return self._client.exists(path) + + def get_data(self, path: str) -> bytes: + # data: bytes, stat: ZnodeStat + data, stat = self._client.get(path) + return data + + def get_children(self, path: str) -> list: + return self._client.get_children(path) + + def delete(self, path: str) -> None: + self._client.delete(path) + + def add_state_listener(self, listener: StateListener) -> None: + with self._state_lock: + if listener in self._state_listeners: + return + state_adapter = StateListenerAdapter(listener) + self._client.add_listener(state_adapter) + self._state_listeners[listener] = state_adapter + + def remove_state_listener(self, listener: StateListener) -> None: + with self._state_lock: + state_adapter = self._state_listeners.pop(listener, None) + if state_adapter is not None: + self._client.remove_listener(state_adapter) + + def add_data_listener(self, path: str, listener: DataListener) -> None: + self._data_adapter_factory.create(path, listener) + + def remove_data_listener(self, listener: DataListener) -> None: + self._data_adapter_factory.remove(listener) + + def add_children_listener(self, path: str, listener: ChildrenListener) -> None: + self._children_adapter_factory.create(path, listener) + + def remove_children_listener(self, listener: ChildrenListener) -> None: + self._children_adapter_factory.remove(listener) + + +class KazooZookeeperTransport(ZookeeperTransport): + + def __init__(self): + self._lock = threading.Lock() + # key: location, value: KazooZookeeperClient + self._zk_client_dict: Dict[str, KazooZookeeperClient] = {} + + def connect(self, url: URL) -> ZookeeperClient: + with self._lock: + zk_client = self._zk_client_dict.get(url.location) + if zk_client is None or zk_client.is_connected(): + # Create new KazooZookeeperClient + zk_client = KazooZookeeperClient(url) + zk_client.start() + self._zk_client_dict[url.location] = zk_client + + return zk_client diff --git a/dubbo/registry/zookeeper/zk_registry.py b/dubbo/registry/zookeeper/zk_registry.py new file mode 100644 index 0000000..4b4e6c7 --- /dev/null +++ b/dubbo/registry/zookeeper/zk_registry.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dubbo.common import URL +from dubbo.common import constants as common_constants +from dubbo.logger import loggerFactory +from dubbo.registry import Registry, RegistryFactory + +from ._interfaces import StateListener, ZookeeperTransport +from .kazoo_transport import KazooZookeeperTransport + +_LOGGER = loggerFactory.get_logger(__name__) + + +class ZookeeperRegistry(Registry): + DEFAULT_ROOT = "dubbo" + + def __init__(self, url: URL, zk_transport: ZookeeperTransport): + self._url = url + self._zk_client = zk_transport.connect(self._url) + + self._root = self._url.parameters.get( + common_constants.GROUP_KEY, self.DEFAULT_ROOT + ) + if not self._root.startswith(common_constants.PATH_SEPARATOR): + self._root = common_constants.PATH_SEPARATOR + self._root + + class _StateListener(StateListener): + def state_changed(self, state: "StateListener.State") -> None: + if state == StateListener.State.LOST: + _LOGGER.warning("Connection lost") + elif state == StateListener.State.CONNECTED: + _LOGGER.info("Connection established") + elif state == StateListener.State.SUSPENDED: + _LOGGER.info("Connection suspended") + + self._zk_client.add_state_listener(_StateListener()) + + def register(self, url: URL) -> None: + pass + + def unregister(self, url: URL) -> None: + pass + + def subscribe(self, url: URL, listener): + pass + + def unsubscribe(self, url: URL, listener): + pass + + def lookup(self, url: URL): + pass + + def get_url(self) -> URL: + return self._url + + def is_available(self) -> bool: + return self._zk_client and self._zk_client.is_connected() + + def destroy(self) -> None: + if self._zk_client: + self._zk_client.stop() + + def check_destroy(self) -> None: + if not self._zk_client: + raise RuntimeError("registry is destroyed") + + +class ZookeeperRegistryFactory(RegistryFactory): + + def __init__(self): + self._transport: ZookeeperTransport = KazooZookeeperTransport() + + def get_registry(self, url: URL) -> Registry: + return ZookeeperRegistry(url, self._transport) diff --git a/dubbo/remoting/__init__.py b/dubbo/remoting/__init__.py index bcba37a..a93961f 100644 --- a/dubbo/remoting/__init__.py +++ b/dubbo/remoting/__init__.py @@ -13,3 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from ._interfaces import Client, Server, Transporter + +__all__ = ["Client", "Server", "Transporter"] diff --git a/dubbo/remoting/transporter.py b/dubbo/remoting/_interfaces.py similarity index 55% rename from dubbo/remoting/transporter.py rename to dubbo/remoting/_interfaces.py index f56dc5f..b2181a7 100644 --- a/dubbo/remoting/transporter.py +++ b/dubbo/remoting/_interfaces.py @@ -13,60 +13,104 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dubbo.url import URL +import abc -class Client: +from dubbo.common import URL + +__all__ = ["Client", "Server", "Transporter"] + + +class Client(abc.ABC): def __init__(self, url: URL): self._url = url + @abc.abstractmethod def is_connected(self) -> bool: """ Check if the client is connected. """ - raise NotImplementedError("is_connected() is not implemented.") + raise NotImplementedError() + @abc.abstractmethod def is_closed(self) -> bool: """ Check if the client is closed. """ - raise NotImplementedError("is_closed() is not implemented.") + raise NotImplementedError() + @abc.abstractmethod def connect(self): """ Connect to the server. """ - raise NotImplementedError("connect() is not implemented.") + raise NotImplementedError() + @abc.abstractmethod def reconnect(self): """ Reconnect to the server. """ - raise NotImplementedError("reconnect() is not implemented.") + raise NotImplementedError() + @abc.abstractmethod def close(self): """ Close the client. """ - raise NotImplementedError("close() is not implemented.") + raise NotImplementedError() class Server: - # TODO define the interface of the server. - pass + """ + Server + """ + + @abc.abstractmethod + def is_exported(self) -> bool: + """ + Check if the server is exported. + """ + raise NotImplementedError() + + @abc.abstractmethod + def is_closed(self) -> bool: + """ + Check if the server is closed. + """ + raise NotImplementedError() + + @abc.abstractmethod + def export(self): + """ + Export the server. + """ + raise NotImplementedError() + + @abc.abstractmethod + def close(self): + """ + Close the server. + """ + raise NotImplementedError() -class Transporter: +class Transporter(abc.ABC): + """ + Transporter interface + """ + @abc.abstractmethod def connect(self, url: URL) -> Client: """ Connect to a server. """ - raise NotImplementedError("connect() is not implemented.") + raise NotImplementedError() + @abc.abstractmethod def bind(self, url: URL) -> Server: """ Bind a server. """ - raise NotImplementedError("bind() is not implemented.") + raise NotImplementedError() diff --git a/dubbo/remoting/aio/aio_transporter.py b/dubbo/remoting/aio/aio_transporter.py index dc97db4..e721195 100644 --- a/dubbo/remoting/aio/aio_transporter.py +++ b/dubbo/remoting/aio/aio_transporter.py @@ -13,19 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import asyncio import concurrent -import threading from typing import Optional -from dubbo.constants import common_constants -from dubbo.logger.logger_factory import loggerFactory +from dubbo.common import constants as common_constants +from dubbo.common.url import URL +from dubbo.common.utils import FutureHelper +from dubbo.logger import loggerFactory +from dubbo.remoting._interfaces import Client, Server, Transporter +from dubbo.remoting.aio import constants as aio_constants from dubbo.remoting.aio.event_loop import EventLoop -from dubbo.remoting.aio.exceptions import RemotingException -from dubbo.remoting.transporter import Client, Server, Transporter -from dubbo.url import URL +from dubbo.remoting.aio.exceptions import RemotingError -logger = loggerFactory.get_logger(__name__) +_LOGGER = loggerFactory.get_logger(__name__) class AioClient(Client): @@ -35,6 +37,15 @@ class AioClient(Client): url(URL): The configuration of the client. 
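+
+    Note:
+        The connection is driven on a dedicated asyncio event loop; connect() is bridged to
+        that loop via run_coroutine_threadsafe(), while close() blocks on a
+        concurrent.futures.Future that is completed once the connection is lost.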
""" + __slots__ = [ + "_protocol", + "_connected", + "_close_future", + "_closing", + "_closed", + "_event_loop", + ] + def __init__(self, url: URL): super().__init__(url) @@ -42,17 +53,15 @@ def __init__(self, url: URL): self._protocol = None # the event to indicate the connection status of the client - self._connect_event = threading.Event() + self._connected = False + # the event to indicate the close status of the client self._close_future = concurrent.futures.Future() self._closing = False + self._closed = False - self._url.add_parameter( - common_constants.TRANSPORTER_SIDE_KEY, - common_constants.TRANSPORTER_SIDE_CLIENT, - ) - self._url.attributes["connect-event"] = self._connect_event - self._url.attributes["close-future"] = self._close_future + self._url.parameters[common_constants.SIDE_KEY] = common_constants.CLIENT_VALUE + self._url.attributes[aio_constants.CLOSE_FUTURE_KEY] = self._close_future self._event_loop: Optional[EventLoop] = None @@ -63,20 +72,20 @@ def is_connected(self) -> bool: """ Check if the client is connected. """ - return self._connect_event.is_set() + return self._connected def is_closed(self) -> bool: """ Check if the client is closed. """ - return self._close_future.done() or self._closing + return self._closed or self._closing def reconnect(self) -> None: """ Reconnect to the server. """ self.close() - self._connect_event = threading.Event() + self._connected = False self._close_future = concurrent.futures.Future() self.connect() @@ -87,17 +96,17 @@ def connect(self) -> None: if self.is_connected(): return elif self.is_closed(): - raise RemotingException("The client is closed.") + raise RemotingError("The client is closed.") - async def _inner_operate(): + async def _inner_operation(): running_loop = asyncio.get_running_loop() + # Create the connection. _, protocol = await running_loop.create_connection( - lambda: self._url.attributes[common_constants.TRANSPORTER_PROTOCOL_KEY]( - self._url - ), + lambda: self._url.attributes[common_constants.PROTOCOL_KEY](self._url), self._url.host, self._url.port, ) + # Set the protocol. return protocol # Run the connection logic in the event loop. @@ -107,12 +116,18 @@ async def _inner_operate(): self._event_loop.start() future = asyncio.run_coroutine_threadsafe( - _inner_operate(), self._event_loop.loop + _inner_operation(), self._event_loop.loop ) try: self._protocol = future.result() + self._connected = True + _LOGGER.info( + "Connected to the server. 
host: %s, port: %s", + self._url.host, + self._url.port, + ) except ConnectionRefusedError as e: - raise RemotingException("Failed to connect to the server") from e + raise RemotingError("Failed to connect to the server") from e def close(self) -> None: """ @@ -120,19 +135,19 @@ def close(self) -> None: """ if self.is_closed(): return - self._closing = True + + def _on_close(_future: concurrent.futures.Future): + self._closed = True if _future.done() else False + + self._close_future.add_done_callback(_on_close) + try: self._protocol.close() - if exc := self._protocol.exception(): - raise RemotingException(f"Failed to close the client: {exc}") - except Exception as e: - if not isinstance(e, RemotingException): - # Ignore the exception if it is not RemotingException - pass - else: - # Re-raise RemotingException - raise e + exc = self._close_future.exception() + if exc: + raise RemotingError(f"Failed to close the client: {exc}") + _LOGGER.info("Closed the client.") finally: self._event_loop.stop() self._closing = False @@ -146,6 +161,89 @@ class AioServer(Server): def __init__(self, url: URL): self._url = url # Set the side of the transporter to server. + self._url.parameters[common_constants.SIDE_KEY] = common_constants.SERVER_VALUE + + # the event to indicate the close status of the server + self._event_loop = EventLoop() + self._event_loop.start() + + # Whether the server is exporting + self._exporting = False + # Whether the server is exported + self._exported = False + + # Whether the server is closing + self._closing = False + # Whether the server is closed + self._closed = False + + # start the server + self.export() + + def is_exported(self) -> bool: + return self._exported or self._exporting + + def is_closed(self) -> bool: + return self._closed or self._closing + + def export(self): + """ + Export the server. + """ + if self.is_exported(): + return + elif self.is_closed(): + raise RemotingError("The server is closed.") + + async def _inner_operation(_future: concurrent.futures.Future): + try: + running_loop = asyncio.get_running_loop() + server = await running_loop.create_server( + lambda: self._url.attributes[common_constants.PROTOCOL_KEY]( + self._url + ), + self._url.host, + self._url.port, + ) + + # Serve the server forever + async with server: + FutureHelper.set_result(_future, None) + await server.serve_forever() + except Exception as e: + FutureHelper.set_exception(_future, e) + + # Run the server logic in the event loop. + future = concurrent.futures.Future() + asyncio.run_coroutine_threadsafe( + _inner_operation(future), self._event_loop.loop + ) + + try: + exc = future.exception() + if exc: + raise RemotingError("Failed to export the server") from exc + else: + self._exported = True + _LOGGER.info("Exported the server. port: %s", self._url.port) + finally: + self._exporting = False + + def close(self): + """ + Close the server. + """ + if self.is_closed(): + return + self._closing = True + + try: + self._event_loop.stop() + self._closed = True + except Exception as e: + raise RemotingError("Failed to close the server") from e + finally: + self._closing = False class AioTransporter(Transporter): diff --git a/dubbo/remoting/aio/constants.py b/dubbo/remoting/aio/constants.py new file mode 100644 index 0000000..e26d52e --- /dev/null +++ b/dubbo/remoting/aio/constants.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["STREAM_HANDLER_KEY"] + +STREAM_HANDLER_KEY = "stream-handler" + +CLOSE_FUTURE_KEY = "close-future" diff --git a/dubbo/remoting/aio/event_loop.py b/dubbo/remoting/aio/event_loop.py index 26de787..5f0df4e 100644 --- a/dubbo/remoting/aio/event_loop.py +++ b/dubbo/remoting/aio/event_loop.py @@ -13,14 +13,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import asyncio import threading import uuid from typing import Optional -from dubbo.logger.logger_factory import loggerFactory +from dubbo.logger import loggerFactory -logger = loggerFactory.get_logger(__name__) +_LOGGER = loggerFactory.get_logger(__name__) def _try_use_uvloop() -> None: @@ -33,7 +34,7 @@ def _try_use_uvloop() -> None: # Check if the operating system. if os.name == "nt": # Windows is not supported. - logger.warning( + _LOGGER.warning( "Unable to use uvloop, because it is not supported on your operating system." ) return @@ -43,7 +44,7 @@ def _try_use_uvloop() -> None: import uvloop except ImportError: # uvloop is not available. - logger.warning( + _LOGGER.warning( "Unable to use uvloop, because it is not installed. " "You can install it by running `pip install uvloop`." ) diff --git a/dubbo/remoting/aio/exceptions.py b/dubbo/remoting/aio/exceptions.py index 4f3d1d6..f941615 100644 --- a/dubbo/remoting/aio/exceptions.py +++ b/dubbo/remoting/aio/exceptions.py @@ -15,16 +15,20 @@ # limitations under the License. -class RemotingException(RuntimeError): +class RemotingError(Exception): """ The base exception class for remoting. """ def __init__(self, message: str): super().__init__(message) + self.message = message + def __str__(self): + return self.message -class ProtocolException(RemotingException): + +class ProtocolError(RemotingError): """ The exception class for protocol errors. """ @@ -33,7 +37,7 @@ def __init__(self, message: str): super().__init__(message) -class StreamException(RemotingException): +class StreamError(RemotingError): """ The exception class for stream errors. """ diff --git a/dubbo/remoting/aio/http2/controllers.py b/dubbo/remoting/aio/http2/controllers.py index 0534bea..e7be817 100644 --- a/dubbo/remoting/aio/http2/controllers.py +++ b/dubbo/remoting/aio/http2/controllers.py @@ -13,179 +13,151 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +import abc import asyncio import threading +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Dict, Optional, Set from h2.connection import H2Connection -from dubbo.logger.logger_factory import loggerFactory -from dubbo.remoting.aio.http2.frames import DataFrame, HeadersFrame, Http2Frame +from dubbo.common.utils import EventHelper +from dubbo.logger import loggerFactory +from dubbo.remoting.aio.http2.frames import ( + DataFrame, + HeadersFrame, + UserActionFrames, + WindowUpdateFrame, +) from dubbo.remoting.aio.http2.registries import Http2FrameType -from dubbo.remoting.aio.http2.stream import Http2Stream +from dubbo.remoting.aio.http2.stream import DefaultHttp2Stream, Http2Stream -logger = loggerFactory.get_logger(__name__) +__all__ = ["RemoteFlowController", "FrameInboundController", "FrameOutboundController"] +_LOGGER = loggerFactory.get_logger(__name__) -class FollowController: - """ - HTTP/2 stream flow controller. - Note: - This is a thread-unsafe class and must be used in the Http2Protocol class - - Args: - loop: The asyncio event loop. - h2_connection: The H2 connection. - transport: The asyncio transport. - """ - @dataclass - class StreamItem: - """ - The item for storing stream, flag, and event. - Args: - stream: The stream. - half_close: Whether to close the stream after sending the data. - event: This event is triggered when all data has been sent. - """ +class Controller(abc.ABC): + def __init__(self, loop: asyncio.AbstractEventLoop): + self._loop = loop + self._lock = threading.Lock() + self._task: Optional[asyncio.Task] = None + self._started = False + self._closed = False + + def start(self) -> None: + with self._lock: + if self._started: + return + self._task = self._loop.create_task(self._run()) + self._started = True + + @abc.abstractmethod + async def _run(self) -> None: + raise NotImplementedError() + + def close(self) -> None: + with self._lock: + if self._closed or not self._task: + return + self._task.cancel() + self._task = None + +class RemoteFlowController(Controller): + @dataclass + class Item: stream: Http2Stream - half_close: bool - event: asyncio.Event + data: bytearray + end_stream: bool + event: Optional[asyncio.Event] def __init__( self, - loop: asyncio.AbstractEventLoop, h2_connection: H2Connection, transport: asyncio.Transport, + loop: asyncio.AbstractEventLoop, ): - self._loop = loop + super().__init__(loop) self._h2_connection = h2_connection self._transport = transport - # Collection of all streams that need to send data - self._stream_dict: Dict[int, FollowController.StreamItem] = {} - - # Collection of streams that are currently sending data - self._outbound_stream_queue: asyncio.Queue[FollowController.StreamItem] = ( - asyncio.Queue() - ) - - # Collection of streams that are flow-controlled - self._follow_control_dict: Dict[int, FollowController.StreamItem] = {} - - # Actual storage for the data that needs to be sent - self._data_dict: Dict[int, bytearray] = {} - - # The task for sending data. - self._task = None - - def start(self) -> None: - """ - Start the data sender loop. - This creates and starts an asyncio task that runs the _data_sender_loop coroutine. - """ - self._task = self._loop.create_task(self._send_data()) - - def increment_flow_control_window(self, stream_id: Optional[int]) -> None: - """ - Increment the flow control window size. - Args: - stream_id: The stream identifier. If it is None, it means the entire connection. 
- """ + self._stream_dict: Dict[int, RemoteFlowController.Item] = {} + + self._outbound_queue: asyncio.Queue[int] = asyncio.Queue() + self._flow_controls: Set[int] = set() + + # Start the controller + self.start() + + def write_data( + self, stream: Http2Stream, frame: DataFrame, event: Optional[asyncio.Event] + ) -> None: + if stream.local_closed: + EventHelper.set(event) + _LOGGER.warning(f"Stream {stream.id} is closed.") + return + + item = self._stream_dict.get(stream.id) + if item: + # Extend the data if the stream item exists + item.data.extend(frame.data) + item.end_stream = frame.end_stream + # update the event + EventHelper.set(item.event) + item.event = event + else: + # Create a new stream item + item = RemoteFlowController.Item( + stream, bytearray(frame.data), frame.end_stream, event + ) + self._stream_dict[stream.id] = item + self._outbound_queue.put_nowait(stream.id) + + def release_flow_control(self, frame: WindowUpdateFrame) -> None: + stream_id = frame.stream_id if stream_id is None or stream_id == 0: # This is for the entire connection. - for item in self._follow_control_dict.values(): - self._outbound_stream_queue.put_nowait(item) - self._follow_control_dict = {} - elif stream_id in self._follow_control_dict: + for i in self._flow_controls: + self._outbound_queue.put_nowait(i) + self._flow_controls.clear() + elif stream_id in self._flow_controls: # This is specific to a single stream. - item = self._follow_control_dict.pop(stream_id) - self._outbound_stream_queue.put_nowait(item) - - def send_data( - self, - stream: Http2Stream, - data: bytes, - half_close: bool, - event: Union[asyncio.Event, threading.Event] = None, - ): - """ - Send data to the stream.(thread-unsafe) - Note: - Args: - stream: The stream. - data: The data to send. - half_close: Whether to close the stream after sending the data. - event: The event that is triggered when all data has been sent. - """ - - # Check if the stream is closed - if stream.is_local_closed(): - if event: - event.set() - logger.warning(f"Stream {stream.id} is closed. Ignoring data {data}") - else: - # Save the data to the data dictionary - if old_data := self._data_dict.get(stream.id): - old_data.extend(data) - item = self._stream_dict[stream.id] - item.half_close = half_close - # Update the event - if item.event: - item.event.set() - item.event = event - else: - self._data_dict[stream.id] = bytearray(data) - self._stream_dict[stream.id] = FollowController.StreamItem( - stream, half_close, event - ) - - # Put the stream into the outbound stream queue - self._outbound_stream_queue.put_nowait(self._stream_dict[stream.id]) + self._flow_controls.remove(stream_id) + self._outbound_queue.put_nowait(stream_id) - def stop(self) -> None: - """ - Stop the data sender loop. - This cancels the asyncio task that runs the _data_sender_loop coroutine. - """ - if self._task: - self._task.cancel() - - async def _send_data(self) -> None: - """ - Coroutine that continuously sends data frames from the outbound data queue while respecting flow control limits. - """ + async def _run(self) -> None: while True: # get the data to send.(async blocking) - item = await self._outbound_stream_queue.get() + stream_id = await self._outbound_queue.get() # check if the stream is closed + item = self._stream_dict[stream_id] stream = item.stream - if stream.is_local_closed(): + if stream.local_closed: # The local side of the stream is closed, so we don't need to send any data. 
- if item.event: - item.event.set() + EventHelper.set(item.event) continue # get the flow control window size - data = self._data_dict.get(stream.id, bytearray()) + data = item.data window_size = self._h2_connection.local_flow_control_window(stream.id) chunk_size = min(window_size, len(data)) data_to_send = data[:chunk_size] data_to_buffer = data[chunk_size:] # send the data - if data_to_send or item.half_close: + if data_to_send or item.end_stream: max_size = self._h2_connection.max_outbound_frame_size # Split the data into chunks and send them out for x in range(0, len(data_to_send), max_size): chunk = data_to_send[x : x + max_size] end_stream_flag = ( - item.half_close - and data_to_buffer == b"" - and x + max_size >= len(data_to_send) + item.end_stream + and not data_to_buffer + and (x + max_size >= len(data_to_send)) ) self._h2_connection.send_data( stream.id, chunk, end_stream=end_stream_flag @@ -201,148 +173,222 @@ async def _send_data(self) -> None: if data_to_buffer: # Save the data that could not be sent due to flow control limits - self._follow_control_dict[stream.id] = item - self._data_dict[stream.id] = data_to_buffer + item.data = data_to_buffer + self._flow_controls.add(stream.id) else: # If all data has been sent, trigger the event. - self._data_dict.pop(stream.id) - if item.event: - item.event.set() + self._stream_dict.pop(stream.id) + EventHelper.set(item.event) + if item.end_stream: + stream.close_local() -class FrameOrderController: +class FrameInboundController(Controller): """ - HTTP/2 frame writer. This class is responsible for writing frames in the correct order. - Note: - Some special frames do not need to be sorted through this queue, such as RST_STREAM, WINDOW_UPDATE, etc. - Args: - stream: The stream to which the frame belongs. - loop: The asyncio event loop. - protocol: The HTTP/2 protocol. + HTTP/2 frame inbound controller. + This class is responsible for reading frames in the correct order. """ - def __init__(self, stream: Http2Stream, loop: asyncio.AbstractEventLoop, protocol): + def __init__( + self, + stream: Http2Stream, + loop: asyncio.AbstractEventLoop, + protocol, + executor: Optional[ThreadPoolExecutor] = None, + ): + """ + Initialize the FrameInboundController. + :param stream: The stream. + :type stream: Http2Stream + :param loop: The asyncio event loop. + :type loop: asyncio.AbstractEventLoop + :param protocol: The HTTP/2 protocol. + :param executor: The thread pool executor for handling frames. + :type executor: Optional[ThreadPoolExecutor] + """ from dubbo.remoting.aio.http2.protocol import Http2Protocol - self._stream: Http2Stream = stream - self._loop: asyncio.AbstractEventLoop = loop + super().__init__(loop) + + self._stream = stream self._protocol: Http2Protocol = protocol + self._executor = executor - # The queue for writing frames. -> keep the order of frames - self._frame_queue: asyncio.PriorityQueue = asyncio.PriorityQueue() - # The task for writing frames. - self._send_frame_task: Optional[asyncio.Task] = None + # The queue for receiving frames. + self._inbound_queue: asyncio.Queue[UserActionFrames] = asyncio.Queue() - # some events - # This event is triggered when a HEADERS frame is placed in the queue. - self._start_event = asyncio.Event() - # This event is triggered when the headers are sent. - self._headers_sent_event: Optional[asyncio.Event] = None - # This event is triggered when the data is sent. 
- self._data_sent_event: Optional[asyncio.Event] = None + self._condition: asyncio.Condition = asyncio.Condition() - # The trailers frame. - self._trailers: Optional[HeadersFrame] = None + # Start the controller + self.start() - def start(self) -> None: + def write_frame(self, frame: UserActionFrames) -> None: """ - Start the frame writer loop. - This creates and starts an asyncio task that runs the _frame_writer_loop coroutine. + Put the frame into the frame queue (thread-unsafe). + :param frame: The HTTP/2 frame to put into the queue. """ - self._send_frame_task = self._loop.create_task(self._write_frame()) + self._inbound_queue.put_nowait(frame) - def write_headers(self, frame: HeadersFrame) -> None: + def ack_frame(self, frame: UserActionFrames) -> None: """ - Write the headers frame to the frame writer queue.(thread-safe) - Args: - frame: The headers frame. + Acknowledge the frame by setting the frame event.(thread-safe) """ - def _inner_operation(_frame: Http2Frame): - # put the frame into the queue - self._frame_queue.put_nowait((0, _frame)) - # trigger the start event - self._start_event.set() + async def _inner_operation(_frame: UserActionFrames): + async with self._condition: + if _frame.frame_type == Http2FrameType.DATA: + self._protocol.ack_received_data(_frame.stream_id, _frame.padding) + self._condition.notify_all() - self._loop.call_soon_threadsafe(_inner_operation, frame) + asyncio.run_coroutine_threadsafe(_inner_operation(frame), self._loop) - def write_data(self, frame: DataFrame, last: bool = False) -> None: + async def _run(self) -> None: """ - Write the data frame to the frame writer queue.(thread-safe) - Args: - frame: The data frame. - last: Unlike end_stream, this flag indicates whether the current frame is the last data frame or not. + Coroutine that continuously reads frames from the frame queue. """ + while True: + async with self._condition: + # get the frame from the queue + frame = await self._inbound_queue.get() + + if self._stream.remote_closed: + # The remote side of the stream is closed, so we don't need to process any more frames. + break + + # handle frame in the thread pool + self._loop.run_in_executor(self._executor, self._handle_frame, frame) + + if not frame.end_stream: + # Waiting for the previous frame to be processed + await self._condition.wait() + else: + # close the stream remotely + self._stream.close_remote() + break + + def _handle_frame(self, frame: UserActionFrames): + listener = self._stream.listener + # match the frame type + frame_type = frame.frame_type + if frame_type == Http2FrameType.HEADERS: + listener.on_headers(frame.headers, frame.end_stream) + elif frame_type == Http2FrameType.DATA: + listener.on_data(frame.data, frame.end_stream) + elif frame_type == Http2FrameType.RST_STREAM: + listener.cancel_by_remote(frame.error_code) + else: + _LOGGER.warning(f"unprocessed frame type: {frame.frame_type}") - def _inner_operation(_frame: Http2Frame, _last: bool): - # put the frame into the queue - self._frame_queue.put_nowait((1, _frame)) - if _last: - # put the trailers frame into the queue - if self._trailers: - self._frame_queue.put_nowait((2, self._trailers)) + # acknowledge the frame + self.ack_frame(frame) - self._loop.call_soon_threadsafe(_inner_operation, frame, last) - def write_trailers(self, frame: HeadersFrame) -> None: +class FrameOutboundController(Controller): + """ + HTTP/2 frame outbound controller. + This class is responsible for writing frames in the correct order. 
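+
+    The order guaranteed to the peer is: a single initial HEADERS frame, then any number of
+    DATA frames, and optionally a trailing HEADERS (trailers) frame after the last DATA
+    frame has been flushed.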
+ """ + + LAST_DATA_FRAME = DataFrame(-1, b"", 0) + + def __init__( + self, stream: DefaultHttp2Stream, loop: asyncio.AbstractEventLoop, protocol + ): + from dubbo.remoting.aio.http2.protocol import Http2Protocol + + super().__init__(loop) + + self._stream = stream + self._protocol: Http2Protocol = protocol + + self._headers_put_event: asyncio.Event = asyncio.Event() + self._headers_sent_event: asyncio.Event = asyncio.Event() + self._headers: Optional[HeadersFrame] = None + + self._data_queue: asyncio.Queue[DataFrame] = asyncio.Queue() + self._data_sent_event: asyncio.Event = asyncio.Event() + + self._trailers: Optional[HeadersFrame] = None + + # Start the controller + self.start() + + def write_headers(self, frame: HeadersFrame) -> None: """ - Write the trailers frame to the frame writer queue.(thread-safe) - Note: - This method is suitable for cases where data frames are not to be sent - Args: - frame: The trailers frame. + Write the headers frame by order.(thread-safe) + :param frame: The headers frame. + :type frame: HeadersFrame """ - def _inner_operation(_frame: Http2Frame): - # put the frame into the queue - self._frame_queue.put_nowait((2, _frame)) + def _inner_operation(_frame: HeadersFrame): + if not self._headers: + # send the frame directly -> the headers frame is the first frame + self._headers = _frame + EventHelper.set(self._headers_put_event) + else: + # put the frame into the queue -> the headers frame is not the first frame(trailers) + self._trailers = _frame + # Notify the data queue that the last data frame has reached. + self._data_queue.put_nowait(FrameOutboundController.LAST_DATA_FRAME) self._loop.call_soon_threadsafe(_inner_operation, frame) - def write_trailers_after_data(self, frame: HeadersFrame) -> None: + def write_data(self, frame: DataFrame) -> None: """ - Write the trailers frame to the frame writer queue.(thread-safe) - Note: - This method is used to write trailers after the data frame. - If the data frame is not sent completely, the trailers frame will not be sent. + Write the data frame by order.(thread-safe) + :param frame: The data frame. + :type frame: DataFrame """ - self._trailers = frame + self._loop.call_soon_threadsafe(self._data_queue.put_nowait, frame) - async def _write_frame(self) -> None: + def write_rst(self, frame: UserActionFrames) -> None: """ - Coroutine that continuously writes frames from the frame queue. + Write the reset frame directly.(thread-safe) + :param frame: The reset frame. + :type frame: UserActionFrames """ - while True: - # wait for the start event - await self._start_event.wait() - # get the frame from the queue -> block & async - _, frame = await self._frame_queue.get() + def _inner_operation(_frame: UserActionFrames): + self._protocol.send_frame(_frame, self._stream) - # write the frame - if frame.frame_type == Http2FrameType.HEADERS: - self._headers_sent_event = self._protocol.write(frame, self._stream) - else: - # await the headers sent event - await self._headers_sent_event.wait() + self._stream.close_local() + self._stream.close_remote() + + self._loop.call_soon_threadsafe(_inner_operation, frame) + + async def _run(self) -> None: + """ + Coroutine that continuously writes frames from the frame queue. 
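+
+        Headers are sent once the headers-put event fires; DATA frames are then relayed in
+        arrival order, and when the LAST_DATA_FRAME sentinel is dequeued the buffered
+        trailers are sent and the local side of the stream is closed.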
+ """ + + # wait and send the headers frame + await self._headers_put_event.wait() + self._protocol.send_frame(self._headers, self._stream, self._headers_sent_event) - # await the data sent event - if self._data_sent_event: - await self._data_sent_event.wait() + # check if the headers frame is the last frame + if self._headers.end_stream: + self._stream.close_local() + return - self._data_sent_event = self._protocol.write(frame, self._stream) + # wait for the headers sent event + await self._headers_sent_event.wait() - # check if the frame is the last frame - if frame.end_stream: - # close the stream - if frame.frame_type != Http2FrameType.DATA: - self._stream.close_local() + # wait and send the data frames + while True: + frame = await self._data_queue.get() + if frame is not FrameOutboundController.LAST_DATA_FRAME: + self._data_sent_event = asyncio.Event() + self._protocol.send_frame(frame, self._stream, self._data_sent_event) + if frame.end_stream: + # The last frame has been sent, so the stream is closed. + return + else: + # The last frame has been reached. break - def stop(self) -> None: - """ - Stop the frame writer loop. - This cancels the asyncio task that runs the _frame_writer_loop coroutine. - """ - if self._send_frame_task: - self._send_frame_task.cancel() + # wait for the last data frame and send the trailers frame + await self._data_sent_event.wait() + self._protocol.send_frame(self._trailers, self._stream) + + # close the stream + self._stream.close_local() diff --git a/dubbo/remoting/aio/http2/frames.py b/dubbo/remoting/aio/http2/frames.py index 173e29b..2733b8d 100644 --- a/dubbo/remoting/aio/http2/frames.py +++ b/dubbo/remoting/aio/http2/frames.py @@ -13,11 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import time + +from typing import Union from dubbo.remoting.aio.http2.headers import Http2Headers from dubbo.remoting.aio.http2.registries import Http2ErrorCode, Http2FrameType +__all__ = [ + "Http2Frame", + "HeadersFrame", + "DataFrame", + "WindowUpdateFrame", + "ResetStreamFrame", + "UserActionFrames", +] + class Http2Frame: """ @@ -27,6 +37,8 @@ class Http2Frame: frame_type: The frame type. """ + __slots__ = ["stream_id", "frame_type", "end_stream", "timestamp"] + def __init__( self, stream_id: int, @@ -37,12 +49,6 @@ def __init__( self.frame_type = frame_type self.end_stream = end_stream - # The timestamp of the generated frame. -> comparison for Priority Queue - self.timestamp = int(round(time.time() * 1000)) - - def __lt__(self, other: "Http2Frame") -> bool: - return self.timestamp <= other.timestamp - def __repr__(self) -> str: return f"" @@ -56,6 +62,8 @@ class HeadersFrame(Http2Frame): end_stream: Whether the stream is ended. """ + __slots__ = ["headers"] + def __init__( self, stream_id: int, @@ -75,20 +83,22 @@ class DataFrame(Http2Frame): Args: stream_id: The stream identifier. data: The data to send. - data_length: The amount of data received that counts against the flow control window. + length: The amount of data received that counts against the flow control window. 
end_stream: Whether the stream """ + __slots__ = ["data", "padding"] + def __init__( self, stream_id: int, data: bytes, - data_length: int, + length: int, end_stream: bool = False, ): super().__init__(stream_id, Http2FrameType.DATA, end_stream) self.data = data - self.data_length = data_length + self.padding = length def __repr__(self) -> str: return f"" @@ -102,6 +112,8 @@ class WindowUpdateFrame(Http2Frame): delta: The number of bytes by which to increase the flow control window. """ + __slots__ = ["delta"] + def __init__( self, stream_id: int, @@ -122,6 +134,8 @@ class ResetStreamFrame(Http2Frame): error_code: The error code that indicates the reason for closing the stream. """ + __slots__ = ["error_code"] + def __init__( self, stream_id: int, @@ -132,3 +146,6 @@ def __init__( def __repr__(self) -> str: return f"" + + +UserActionFrames = Union[HeadersFrame, DataFrame, ResetStreamFrame] diff --git a/dubbo/remoting/aio/http2/headers.py b/dubbo/remoting/aio/http2/headers.py index 293248f..f50e314 100644 --- a/dubbo/remoting/aio/http2/headers.py +++ b/dubbo/remoting/aio/http2/headers.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import enum from collections import OrderedDict from typing import List, Optional, Tuple, Union @@ -32,10 +33,19 @@ class PseudoHeaderName(enum.Enum): # Response pseudo-headers STATUS = ":status" + @classmethod + def to_list(cls) -> List[str]: + """ + Get all pseudo-header names. + Returns: + The pseudo-header names list. + """ + return [header.value for header in cls] + -class MethodType(enum.Enum): +class HttpMethod(enum.Enum): """ - HTTP/2 method types. + HTTP method types. """ GET = "GET" @@ -54,50 +64,29 @@ class Http2Headers: HTTP/2 headers. """ + __slots__ = ["_headers"] + def __init__(self): self._headers: OrderedDict[str, Optional[str]] = OrderedDict() self._init() def _init(self): # keep the order of headers - self._headers[PseudoHeaderName.SCHEME.value] = None - self._headers[PseudoHeaderName.METHOD.value] = None - self._headers[PseudoHeaderName.AUTHORITY.value] = None - self._headers[PseudoHeaderName.PATH.value] = None - self._headers[PseudoHeaderName.STATUS.value] = None + self._headers = {name: "" for name in PseudoHeaderName.to_list()} def add(self, name: str, value: str) -> None: - """ - Add a header. - Args: - name: The header name. - value: The header value. - """ - self._headers[name] = value + self._headers[name] = str(value) - def get(self, name: str) -> Optional[str]: - """ - Get the header value. - Returns: - The header value: If the header exists, return the value. Otherwise, return None. - """ - return self._headers.get(name, None) + def get(self, name: str, default: Optional[str] = None) -> Optional[str]: + return self._headers.get(name, default) @property def method(self) -> Optional[str]: - """ - Get the method. - """ return self.get(PseudoHeaderName.METHOD.value) @method.setter - def method(self, value: Union[MethodType, str]) -> None: - """ - Set the method. - Args: - value: The method value. - """ - if isinstance(value, MethodType): + def method(self, value: Union[HttpMethod, str]) -> None: + if isinstance(value, HttpMethod): value = value.value else: value = value.upper() @@ -105,77 +94,61 @@ def method(self, value: Union[MethodType, str]) -> None: @property def scheme(self) -> Optional[str]: - """ - Get the scheme. 
- """ return self.get(PseudoHeaderName.SCHEME.value) @scheme.setter def scheme(self, value: str) -> None: - """ - Set the scheme. - Args: - value: The scheme value. - """ self.add(PseudoHeaderName.SCHEME.value, value) @property def authority(self) -> Optional[str]: - """ - Get the authority. - """ return self.get(PseudoHeaderName.AUTHORITY.value) @authority.setter def authority(self, value: str) -> None: - """ - Set the authority. - Args: - value: The authority value. - """ self.add(PseudoHeaderName.AUTHORITY.value, value) @property def path(self) -> Optional[str]: - """ - Get the path. - """ return self.get(PseudoHeaderName.PATH.value) @path.setter def path(self, value: str) -> None: - """ - Set the path. - Args: - value: The path value. - """ self.add(PseudoHeaderName.PATH.value, value) @property def status(self) -> Optional[str]: - """ - Get the status code. - """ return self.get(PseudoHeaderName.STATUS.value) @status.setter def status(self, value: str) -> None: - """ - Set the status code. - Args: - value: The status code. - """ self.add(PseudoHeaderName.STATUS.value, value) def to_list(self) -> List[Tuple[str, str]]: """ Convert the headers to a list. The list contains all non-None headers. - Returns: - The headers list. + :return: The headers list. + :rtype: List[Tuple[str, str]] + """ + headers = [] + pseudo_headers = PseudoHeaderName.to_list() + for name, value in list(self._headers.items()): + if name in pseudo_headers and value == "": + continue + headers.append((str(name), str(value) or "")) + return headers + + def to_dict(self) -> OrderedDict[str, str]: + """ + Convert the headers to an ordered dict. + :return: The headers' dict. + :rtype: OrderedDict[str, Optional[str]] """ - return [ - (name, value) for name, value in self._headers.items() if value is not None - ] + headers_dict = OrderedDict() + for key, value in self._headers.items(): + if value is not None and value != "": + headers_dict[key] = value + return headers_dict def __repr__(self) -> str: return f"" @@ -190,6 +163,5 @@ def from_list(cls, headers: List[Tuple[str, str]]) -> "Http2Headers": The Http2Headers object. """ http2_headers = cls() - for name, value in headers: - http2_headers.add(name, value) + http2_headers._headers = dict(headers) return http2_headers diff --git a/dubbo/remoting/aio/http2/protocol.py b/dubbo/remoting/aio/http2/protocol.py index e42bb9b..7276412 100644 --- a/dubbo/remoting/aio/http2/protocol.py +++ b/dubbo/remoting/aio/http2/protocol.py @@ -13,27 +13,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import asyncio -import concurrent -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple from h2.config import H2Configuration from h2.connection import H2Connection -from dubbo.constants import common_constants -from dubbo.logger.logger_factory import loggerFactory -from dubbo.remoting.aio.exceptions import ProtocolException -from dubbo.remoting.aio.http2.controllers import FollowController -from dubbo.remoting.aio.http2.frames import Http2Frame +from dubbo.common import constants as common_constants +from dubbo.common.url import URL +from dubbo.common.utils import EventHelper, FutureHelper +from dubbo.logger import loggerFactory +from dubbo.remoting.aio import constants as h2_constants +from dubbo.remoting.aio.exceptions import ProtocolError +from dubbo.remoting.aio.http2.controllers import RemoteFlowController +from dubbo.remoting.aio.http2.frames import UserActionFrames from dubbo.remoting.aio.http2.registries import Http2FrameType from dubbo.remoting.aio.http2.stream import Http2Stream from dubbo.remoting.aio.http2.utils import Http2EventUtils -from dubbo.url import URL -logger = loggerFactory.get_logger(__name__) +_LOGGER = loggerFactory.get_logger(__name__) + +__all__ = ["Http2Protocol"] class Http2Protocol(asyncio.Protocol): + """ + HTTP/2 protocol implementation. + """ + + __slots__ = [ + "_url", + "_loop", + "_h2_connection", + "_transport", + "_flow_controller", + "_stream_handler", + ] def __init__(self, url: URL): self._url = url @@ -41,8 +57,8 @@ def __init__(self, url: URL): # Create the H2 state machine side_client = ( - self._url.get_parameter(common_constants.TRANSPORTER_SIDE_KEY) - == common_constants.TRANSPORTER_SIDE_CLIENT + self._url.parameters.get(common_constants.SIDE_KEY) + == common_constants.CLIENT_VALUE ) h2_config = H2Configuration(client_side=side_client, header_encoding="utf-8") self._h2_connection: H2Connection = H2Connection(config=h2_config) @@ -50,11 +66,9 @@ def __init__(self, url: URL): # The transport instance self._transport: Optional[asyncio.Transport] = None - self._follow_controller: Optional[FollowController] = None + self._flow_controller: Optional[RemoteFlowController] = None - self._stream_handler = self._url.attributes[ - common_constants.TRANSPORTER_STREAM_HANDLER_KEY - ] + self._stream_handler = self._url.attributes[h2_constants.STREAM_HANDLER_KEY] def connection_made(self, transport: asyncio.Transport): """ @@ -69,47 +83,32 @@ def connection_made(self, transport: asyncio.Transport): self._transport.write(self._h2_connection.data_to_send()) # Create and start the follow controller - self._follow_controller = FollowController( - self._loop, self._h2_connection, self._transport + self._flow_controller = RemoteFlowController( + self._h2_connection, self._transport, self._loop ) - self._follow_controller.start() # Initialize the stream handler self._stream_handler.do_init(self._loop, self) - # Notify the connection is established - if event := self._url.attributes.get("connect-event"): - event.set() - - def get_next_stream_id( - self, future: Union[asyncio.Future, concurrent.futures.Future] - ) -> None: + def get_next_stream_id(self, future) -> None: """ Create a new stream.(thread-safe) Args: future: The future to set the stream identifier. 
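+                Either an asyncio.Future or a concurrent.futures.Future may be passed; the
+                result is set through FutureHelper on the protocol's event loop so callers
+                on other threads can block on it.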
""" - def _inner_operation(_future: Union[asyncio.Future, concurrent.futures.Future]): + def _inner_operation(_future): stream_id = self._h2_connection.get_next_available_stream_id() - _future.set_result(stream_id) + FutureHelper.set_result(_future, stream_id) self._loop.call_soon_threadsafe(_inner_operation, future) - def write(self, frame: Http2Frame, stream: Http2Stream) -> asyncio.Event: - """ - Send the HTTP/2 frame.(thread-safe) - Args: - frame: The HTTP/2 frame. - stream: The HTTP/2 stream. - Returns: - The event to be set after sending the frame. - """ - event = asyncio.Event() - self._loop.call_soon_threadsafe(self._send_frame, frame, stream, event) - return event - - def _send_frame(self, frame: Http2Frame, stream: Http2Stream, event: asyncio.Event): + def send_frame( + self, + frame: UserActionFrames, + stream: Http2Stream, + event: Optional[asyncio.Event] = None, + ): """ Send the HTTP/2 frame.(thread-unsafe) Args: @@ -123,13 +122,11 @@ def _send_frame(self, frame: Http2Frame, stream: Http2Stream, event: asyncio.Eve frame.stream_id, frame.headers.to_list(), frame.end_stream, event ) elif frame_type == Http2FrameType.DATA: - self._follow_controller.send_data( - stream, frame.data, frame.end_stream, event - ) + self._flow_controller.write_data(stream, frame, event) elif frame_type == Http2FrameType.RST_STREAM: self._send_reset_frame(frame.stream_id, frame.error_code.value, event) else: - logger.warning(f"Unhandled frame: {frame}") + _LOGGER.warning(f"Unhandled frame: {frame}") def _send_headers_frame( self, @@ -148,8 +145,7 @@ def _send_headers_frame( """ self._h2_connection.send_headers(stream_id, headers, end_stream=end_stream) self._transport.write(self._h2_connection.data_to_send()) - if event: - event.set() + EventHelper.set(event) def _send_reset_frame( self, stream_id: int, error_code: int, event: Optional[asyncio.Event] = None @@ -163,8 +159,7 @@ def _send_reset_frame( """ self._h2_connection.reset_stream(stream_id, error_code) self._transport.write(self._h2_connection.data_to_send()) - if event: - event.set() + EventHelper.set(event) def data_received(self, data): events = self._h2_connection.receive_data(data) @@ -175,9 +170,7 @@ def data_received(self, data): if frame is not None: if frame.frame_type == Http2FrameType.WINDOW_UPDATE: # Because flow control may be at the connection level, it is handled here - self._follow_controller.increment_flow_control_window( - frame.stream_id - ) + self._flow_controller.release_flow_control(frame) else: self._stream_handler.handle_frame(frame) @@ -185,11 +178,23 @@ def data_received(self, data): # 1. Events that are handled automatically by the H2 library (e.g. RemoteSettingsChanged, PingReceived). # -> We just need to send it. # 2. Events that are not implemented or do not require attention. -> We'll ignore it for now. - if outbound_data := self._h2_connection.data_to_send(): + outbound_data = self._h2_connection.data_to_send() + if outbound_data: self._transport.write(outbound_data) except Exception as e: - raise ProtocolException("Failed to process the Http/2 event.") from e + raise ProtocolError("Failed to process the Http/2 event.") from e + + def ack_received_data(self, stream_id: int, padding: int): + """ + Acknowledge the received data. + Args: + stream_id: The stream identifier. + padding: The amount of data received that counts against the flow control window. 
+ """ + + self._h2_connection.acknowledge_received_data(padding, stream_id) + self._transport.write(self._h2_connection.data_to_send()) def close(self): """ @@ -204,10 +209,11 @@ def connection_lost(self, exc): """ Called when the connection is lost. """ - self._follow_controller.stop() + self._flow_controller.close() # Notify the connection is established - if future := self._url.attributes.get("close-future"): + future = self._url.attributes.get(h2_constants.CLOSE_FUTURE_KEY) + if future: if exc: - future.set_exception(exc) + FutureHelper.set_exception(future, exc) else: - future.set_result(None) + FutureHelper.set_result(future, None) diff --git a/dubbo/remoting/aio/http2/registries.py b/dubbo/remoting/aio/http2/registries.py index 69ac023..fd07bf2 100644 --- a/dubbo/remoting/aio/http2/registries.py +++ b/dubbo/remoting/aio/http2/registries.py @@ -13,9 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import enum from typing import Optional +__all__ = ["Http2FrameType", "Http2ErrorCode", "Http2Settings", "HttpStatus"] + class Http2FrameType(enum.Enum): """ diff --git a/dubbo/remoting/aio/http2/stream.py b/dubbo/remoting/aio/http2/stream.py index da6ee4a..3124bab 100644 --- a/dubbo/remoting/aio/http2/stream.py +++ b/dubbo/remoting/aio/http2/stream.py @@ -13,266 +13,260 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import abc import asyncio +from concurrent.futures import ThreadPoolExecutor from typing import Optional -from dubbo.remoting.aio.exceptions import StreamException +from dubbo.remoting.aio.exceptions import StreamError from dubbo.remoting.aio.http2.frames import ( DataFrame, HeadersFrame, - Http2Frame, ResetStreamFrame, + UserActionFrames, ) from dubbo.remoting.aio.http2.headers import Http2Headers -from dubbo.remoting.aio.http2.registries import Http2ErrorCode, Http2FrameType +from dubbo.remoting.aio.http2.registries import Http2ErrorCode + +__all__ = ["Http2Stream", "DefaultHttp2Stream"] -class Http2Stream: +class Http2Stream(abc.ABC): """ A "stream" is an independent, bidirectional sequence of frames exchanged between the client and server within an HTTP/2 connection. see: https://datatracker.ietf.org/doc/html/rfc7540#section-5 - Args: - stream_id: The stream identifier. - listener: The stream listener. - loop: The asyncio event loop. - protocol: The HTTP/2 protocol. """ - def __init__( - self, - stream_id: int, - listener: "StreamListener", - loop: asyncio.AbstractEventLoop, - protocol, - ): - from dubbo.remoting.aio.http2.controllers import FrameOrderController - from dubbo.remoting.aio.http2.protocol import Http2Protocol - - self._loop: asyncio.AbstractEventLoop = loop - self._protocol: Http2Protocol = protocol + __slots__ = ["_id", "_listener", "_local_closed", "_remote_closed"] - # The stream identifier. + def __init__(self, stream_id: int, listener: "Http2Stream.Listener"): self._id = stream_id self._listener = listener + self._listener.bind(self) - # The frame order controller. - self._frame_order_controller: FrameOrderController = FrameOrderController( - self, self._loop, self._protocol - ) - self._frame_order_controller.start() - - # Whether the headers have been sent. - self._headers_sent = False - # Whether the headers have been received. 
- self._headers_received = False - - # Indicates whether the frame identified with end_stream was written (and may not have been sent yet). - self._end_stream = False - - # Whether the stream is closed locally or remotely. + # Whether the stream is closed locally. -> it means the stream can't send any more frames. self._local_closed = False + # Whether the stream is closed remotely. -> it means the stream can't receive any more frames. self._remote_closed = False @property def id(self) -> int: + """ + Get the stream identifier. + """ return self._id - def is_headers_sent(self) -> bool: - return self._headers_sent + @property + def listener(self) -> "Http2Stream.Listener": + """ + Get the listener. + """ + return self._listener - def is_local_closed(self) -> bool: + @property + def local_closed(self) -> bool: """ Check if the stream is closed locally. """ return self._local_closed - def close_local(self) -> None: + @property + def remote_closed(self) -> bool: """ - Close the stream locally. + Check if the stream is closed remotely. """ - self._local_closed = True - self._frame_order_controller.stop() + return self._remote_closed - def is_remote_closed(self) -> bool: + def close_local(self) -> None: """ - Check if the stream is closed remotely. + Close the stream locally. """ - return self._remote_closed + if self._local_closed: + return + self._local_closed = True def close_remote(self) -> None: """ Close the stream remotely. """ + if self._remote_closed: + return self._remote_closed = True - def _send_available(self): + @abc.abstractmethod + def send_headers(self, headers: Http2Headers, end_stream: bool = False) -> None: """ - Check if the stream is available for sending frames. + Send the headers. + :param headers: The HTTP/2 headers. + The second send of headers will be treated as trailers (end_stream must be True). + :type headers: Http2Headers + :param end_stream: Whether to close the stream after sending the data. """ - return not self.is_local_closed() and not self._end_stream + raise NotImplementedError() - def send_headers(self, headers: Http2Headers, end_stream: bool = False) -> None: + @abc.abstractmethod + def send_data(self, data: bytes, end_stream: bool = False) -> None: """ - Send the headers.(thread-unsafe) - Args: - headers: The HTTP/2 headers. - end_stream: Whether to close the stream after sending the data. + Send the data. + :param data: The data to send. + :type data: bytes + :param end_stream: Whether to close the stream after sending the data. """ - if self.is_headers_sent(): - raise StreamException("Headers have been sent.") - elif not self._send_available(): - raise StreamException( - "The stream cannot send a frame because it has been closed." - ) + raise NotImplementedError() - headers_frame = HeadersFrame(self.id, headers, end_stream=end_stream) - self._end_stream = end_stream - self._frame_order_controller.write_headers(headers_frame) - - self._headers_sent = True - - def send_data( - self, data: bytes, end_stream: bool = False, last: bool = False - ) -> None: + @abc.abstractmethod + def cancel_by_local(self, error_code: Http2ErrorCode) -> None: """ - Send the data.(thread-unsafe) - Args: - data: The data to send. - end_stream: Whether to close the stream after sending the data. - last: Is it the last data frame? + Cancel the stream locally. -> send RST_STREAM frame. + :param error_code: The error code. 
+ :type error_code: Http2ErrorCode """ - if not self.is_headers_sent(): - raise StreamException("Headers have not been sent.") - elif not self._send_available(): - raise StreamException( - "The stream cannot send a frame because it has been closed." - ) + raise NotImplementedError() - data_frame = DataFrame(self.id, data, len(data), end_stream=end_stream) - self._end_stream = end_stream - self._frame_order_controller.write_data(data_frame, last) - - def send_trailers(self, headers: Http2Headers, send_data: bool) -> None: + class Listener(abc.ABC): """ - Send trailers with the given headers. Optionally, indicate if data frames - need to be sent. + Http2StreamListener is a base class for handling events in an HTTP/2 stream. - Args: - headers: The HTTP/2 headers to be sent as trailers. - send_data: A flag indicating whether data frames need to be sent. + This class provides a set of callback methods that are called when specific + events occur on the stream, such as receiving headers, receiving data, or + resetting the stream. To use this class, create a subclass and implement the + callback methods for the events you want to handle. """ - if not self.is_headers_sent(): - raise StreamException("Headers have not been sent.") - elif not self._send_available(): - raise StreamException( - "The stream cannot send a frame because it has been closed." - ) - trailers_frame = HeadersFrame(self.id, headers, end_stream=True) - self._end_stream = True - if send_data: - self._frame_order_controller.write_trailers_after_data(trailers_frame) - else: - self._frame_order_controller.write_trailers(trailers_frame) + __slots__ = ["_stream"] + + def __init__(self): + self._stream: Optional["Http2Stream"] = None + + def bind(self, stream: "Http2Stream") -> None: + """ + Bind the stream to the listener. + :param stream: The stream to bind. + :type stream: Http2Stream + """ + self._stream = stream + + @property + def stream(self) -> "Http2Stream": + """ + Get the stream. + """ + return self._stream + + @abc.abstractmethod + def on_headers(self, headers: Http2Headers, end_stream: bool) -> None: + """ + Called when the headers are received. + :param headers: The HTTP/2 headers. + :type headers: Http2Headers + :param end_stream: Whether the stream is closed after receiving the headers. + :type end_stream: bool + """ + raise NotImplementedError() + + @abc.abstractmethod + def on_data(self, data: bytes, end_stream: bool) -> None: + """ + Called when the data is received. + :param data: The data. + :type data: bytes + :param end_stream: Whether the stream is closed after receiving the data. + """ + raise NotImplementedError() + + @abc.abstractmethod + def cancel_by_remote(self, error_code: Http2ErrorCode) -> None: + """ + Cancel the stream remotely. + :param error_code: The error code. + :type error_code: Http2ErrorCode + """ + raise NotImplementedError() + + +class DefaultHttp2Stream(Http2Stream): + """ + Default implementation of the Http2Stream. + """ - def send_reset(self, error_code: Http2ErrorCode) -> None: - """ - Send the reset frame.(thread-unsafe) - Args: - error_code: The error code. 
- """ - if self.is_local_closed(): - raise StreamException("The stream has been reset.") + __slots__ = [ + "_loop", + "_protocol", + "_inbound_controller", + "_outbound_controller", + "_headers_sent", + ] - reset_frame = ResetStreamFrame(self.id, error_code) - # It's a special frame, no need to queue, just send it - self._protocol.write(reset_frame, self) - # close the stream locally and remotely - self.close_local() - self.close_remote() + def __init__( + self, + stream_id: int, + listener: "Http2Stream.Listener", + loop: asyncio.AbstractEventLoop, + protocol, + executor: Optional[ThreadPoolExecutor] = None, + ): + # Avoid circular import + from dubbo.remoting.aio.http2.controllers import ( + FrameInboundController, + FrameOutboundController, + ) - def receive_frame(self, frame: Http2Frame) -> None: - """ - Receive a frame from the stream. - Args: - frame: The frame to be received. - """ - if self.is_remote_closed(): - # The stream is closed remotely, ignore the frame - return + super().__init__(stream_id, listener) + self._loop = loop + self._protocol = protocol - if frame.end_stream: - # received end_stream frame, close the stream remotely - self.close_remote() - - frame_type = frame.frame_type - if frame_type == Http2FrameType.HEADERS: - if not self._headers_received: - # HEADERS frame - self._headers_received = True - self._listener.on_headers(frame.headers, frame.end_stream) - else: - # TRAILERS frame - self._listener.on_trailers(frame.headers) - elif frame_type == Http2FrameType.DATA: - self._listener.on_data(frame.data, frame.end_stream) - elif frame_type == Http2FrameType.RST_STREAM: - self._listener.on_reset(frame.error_code) - self.close_local() - - -class StreamListener: - """ - Http2StreamListener is a base class for handling events in an HTTP/2 stream. + # steam inbound controller + self._inbound_controller: FrameInboundController = FrameInboundController( + self, self._loop, self._protocol, executor + ) + # steam outbound controller + self._outbound_controller: FrameOutboundController = FrameOutboundController( + self, self._loop, self._protocol + ) - This class provides a set of callback methods that are called when specific - events occur on the stream, such as receiving headers, receiving data, or - resetting the stream. To use this class, create a subclass and implement the - callback methods for the events you want to handle. - """ + # The flag to indicate whether the headers have been sent. + self._headers_sent = False - def __init__(self): - self._stream: Optional[Http2Stream] = None + def close_local(self) -> None: + super().close_local() + self._outbound_controller.close() - def bind(self, stream: Http2Stream) -> None: - """ - Bind the stream to the listener. - Args: - stream: The stream. - """ - self._stream = stream + def close_remote(self) -> None: + super().close_remote() + self._inbound_controller.close() - def on_headers(self, headers: Http2Headers, end_stream: bool) -> None: - """ - Called when the headers are received. - Args: - headers: The HTTP/2 headers. - end_stream: Whether the stream is closed after receiving the headers. - """ - raise NotImplementedError("on_headers() is not implemented.") + def send_headers(self, headers: Http2Headers, end_stream: bool = False) -> None: + if self.local_closed: + raise StreamError("The stream has been closed locally.") + elif self._headers_sent and not end_stream: + raise StreamError( + "Trailers must be the last frame of the stream(end_stream must be True)." 
+ ) - def on_data(self, data: bytes, end_stream: bool) -> None: - """ - Called when the data is received. - Args: - data: The data. - end_stream: Whether the stream is closed after receiving the data. - """ - raise NotImplementedError("on_data() is not implemented.") + self._headers_sent = True + headers_frame = HeadersFrame(self.id, headers, end_stream=end_stream) + self._outbound_controller.write_headers(headers_frame) - def on_trailers(self, headers: Http2Headers) -> None: - """ - Called when the trailers are received. - Args: - headers: The HTTP/2 headers. - """ - raise NotImplementedError("on_trailers() is not implemented.") + def send_data(self, data: bytes, end_stream: bool = False) -> None: + if self.local_closed: + raise StreamError("The stream has been closed locally.") + elif not self._headers_sent: + raise StreamError("Headers have not been sent.") + data_frame = DataFrame(self.id, data, len(data), end_stream=end_stream) + self._outbound_controller.write_data(data_frame) + + def cancel_by_local(self, error_code: Http2ErrorCode) -> None: + if self.local_closed: + raise StreamError("The stream has been closed locally.") + reset_frame = ResetStreamFrame(self.id, error_code) + self._outbound_controller.write_rst(reset_frame) - def on_reset(self, error_code: Http2ErrorCode) -> None: + def receive_frame(self, frame: UserActionFrames) -> None: """ - Called when the stream is reset. - Args: - error_code: The error code. + Receive the frame. + :param frame: The frame to receive. + :type frame: UserActionFrames """ - raise NotImplementedError("on_reset() is not implemented.") + self._inbound_controller.write_frame(frame) diff --git a/dubbo/remoting/aio/http2/stream_handler.py b/dubbo/remoting/aio/http2/stream_handler.py index b6e7a3e..dfea951 100644 --- a/dubbo/remoting/aio/http2/stream_handler.py +++ b/dubbo/remoting/aio/http2/stream_handler.py @@ -13,17 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import asyncio from concurrent import futures -from typing import Dict, Optional +from typing import Callable, Dict, Optional -from dubbo.logger.logger_factory import loggerFactory -from dubbo.remoting.aio.exceptions import ProtocolException -from dubbo.remoting.aio.http2.frames import Http2Frame +from dubbo.logger import loggerFactory +from dubbo.remoting.aio.exceptions import ProtocolError +from dubbo.remoting.aio.http2.frames import UserActionFrames from dubbo.remoting.aio.http2.registries import Http2FrameType -from dubbo.remoting.aio.http2.stream import Http2Stream, StreamListener +from dubbo.remoting.aio.http2.stream import DefaultHttp2Stream, Http2Stream + +_LOGGER = loggerFactory.get_logger(__name__) -logger = loggerFactory.get_logger(__name__) +_all__ = [ + "StreamMultiplexHandler", + "StreamClientMultiplexHandler", + "StreamServerMultiplexHandler", +] class StreamMultiplexHandler: @@ -31,6 +38,8 @@ class StreamMultiplexHandler: The StreamMultiplexHandler class is responsible for managing the HTTP/2 streams. """ + __slots__ = ["_loop", "_protocol", "_streams", "_executor"] + def __init__(self, executor: Optional[futures.ThreadPoolExecutor] = None): # Import the Http2Protocol class here to avoid circular imports. from dubbo.remoting.aio.http2.protocol import Http2Protocol @@ -39,7 +48,7 @@ def __init__(self, executor: Optional[futures.ThreadPoolExecutor] = None): self._protocol: Optional[Http2Protocol] = None # The map of stream_id to stream. 
- self._streams: Optional[Dict[int, Http2Stream]] = None + self._streams: Optional[Dict[int, DefaultHttp2Stream]] = None # The executor for handling received frames. self._executor = executor @@ -55,7 +64,7 @@ def do_init(self, loop: asyncio.AbstractEventLoop, protocol) -> None: self._protocol = protocol self._streams = {} - def put_stream(self, stream_id: int, stream: Http2Stream) -> None: + def put_stream(self, stream_id: int, stream: DefaultHttp2Stream) -> None: """ Put the stream into the stream map. Args: @@ -64,7 +73,7 @@ def put_stream(self, stream_id: int, stream: Http2Stream) -> None: """ self._streams[stream_id] = stream - def get_stream(self, stream_id: int) -> Optional[Http2Stream]: + def get_stream(self, stream_id: int) -> Optional[DefaultHttp2Stream]: """ Get the stream by stream identifier. Args: @@ -82,28 +91,22 @@ def remove_stream(self, stream_id: int) -> None: """ self._streams.pop(stream_id, None) - def handle_frame(self, frame: Http2Frame) -> None: + def handle_frame(self, frame: UserActionFrames) -> None: """ Handle the HTTP/2 frame. Args: frame: The HTTP/2 frame. """ - if stream := self._streams.get(frame.stream_id): - # Handle the frame in the executor. - self._handle_frame_in_executor(stream, frame) + stream = self._streams.get(frame.stream_id) + if stream: + # It must be ensured that the event loop is not blocked, + # and if there is a blocking operation, the executor must be used. + stream.receive_frame(frame) else: - logger.warning( + _LOGGER.warning( f"Stream {frame.stream_id} not found. Ignoring frame {frame}" ) - def _handle_frame_in_executor(self, stream: Http2Stream, frame: Http2Frame) -> None: - """ - Handle the HTTP/2 frame in the executor. - Args: - frame: The HTTP/2 frame. - """ - self._loop.run_in_executor(self._executor, stream.receive_frame, frame) - def destroy(self) -> None: """ Destroy the StreamMultiplexHandler. @@ -118,24 +121,27 @@ class StreamClientMultiplexHandler(StreamMultiplexHandler): The StreamClientMultiplexHandler class is responsible for managing the HTTP/2 streams on the client side. """ - def create(self, listener: StreamListener) -> Http2Stream: + def create(self, listener: Http2Stream.Listener) -> DefaultHttp2Stream: """ Create a new stream. - Returns: - The created stream. + :param listener: The stream listener. + :type listener: Http2Stream.Listener + :return: The stream. + :rtype: DefaultHttp2Stream """ future = futures.Future() self._protocol.get_next_stream_id(future) try: # block until the stream_id is created stream_id = future.result() - self._streams[stream_id] = Http2Stream( - stream_id, listener, self._loop, self._protocol + new_stream = DefaultHttp2Stream( + stream_id, listener, self._loop, self._protocol, self._executor ) + self.put_stream(stream_id, new_stream) except Exception as e: - raise ProtocolException("Failed to create stream.") from e + raise ProtocolError("Failed to create stream.") from e - return self._streams[stream_id] + return new_stream class StreamServerMultiplexHandler(StreamMultiplexHandler): @@ -143,23 +149,35 @@ class StreamServerMultiplexHandler(StreamMultiplexHandler): The StreamServerMultiplexHandler class is responsible for managing the HTTP/2 streams on the server side. 
""" - def register(self, stream_id: int) -> Http2Stream: + __slots__ = ["_listener_factory"] + + def __init__( + self, + listener_factory: Callable[[], Http2Stream.Listener], + executor: Optional[futures.ThreadPoolExecutor] = None, + ): + super().__init__(executor) + self._listener_factory = listener_factory + + def register(self, stream_id: int) -> DefaultHttp2Stream: """ Register the stream. - Args: - stream_id: The stream identifier. - Returns: - The created stream. + :param stream_id: The stream identifier. + :type stream_id: int + :return: The stream. + :rtype: DefaultHttp2Stream """ - stream = Http2Stream(stream_id, StreamListener(), self._loop, self._protocol) - self._streams[stream_id] = stream - return stream + stream_listener = self._listener_factory() + new_stream = DefaultHttp2Stream( + stream_id, stream_listener, self._loop, self._protocol, self._executor + ) + self.put_stream(stream_id, new_stream) + return new_stream - def handle_frame(self, frame: Http2Frame) -> None: + def handle_frame(self, frame: UserActionFrames) -> None: """ Handle the HTTP/2 frame. - Args: - frame: The HTTP/2 frame. + :param frame: The HTTP/2 frame. """ # Register the stream if the frame is a HEADERS frame. if frame.frame_type == Http2FrameType.HEADERS: diff --git a/dubbo/remoting/aio/http2/utils.py b/dubbo/remoting/aio/http2/utils.py index 8ecb18f..4de376e 100644 --- a/dubbo/remoting/aio/http2/utils.py +++ b/dubbo/remoting/aio/http2/utils.py @@ -13,20 +13,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional + +from typing import Union import h2.events as h2_event from dubbo.remoting.aio.http2.frames import ( DataFrame, HeadersFrame, - Http2Frame, ResetStreamFrame, WindowUpdateFrame, ) from dubbo.remoting.aio.http2.headers import Http2Headers from dubbo.remoting.aio.http2.registries import Http2ErrorCode +__all__ = ["Http2EventUtils"] + class Http2EventUtils: """ @@ -34,7 +36,9 @@ class Http2EventUtils: """ @staticmethod - def convert_to_frame(event: h2_event.Event) -> Optional[Http2Frame]: + def convert_to_frame( + event: h2_event.Event, + ) -> Union[HeadersFrame, DataFrame, ResetStreamFrame, WindowUpdateFrame, None]: """ Convert a h2.events.Event to HTTP/2 Frame. Args: diff --git a/dubbo/serialization.py b/dubbo/serialization.py deleted file mode 100644 index 0a5baa5..0000000 --- a/dubbo/serialization.py +++ /dev/null @@ -1,87 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, Optional - -from dubbo.constants.type_constants import DeserializingFunction, SerializingFunction -from dubbo.logger.logger_factory import loggerFactory - -logger = loggerFactory.get_logger(__name__) - - -class Serialization: - """ - Serialization class - Args: - serializing_function(SerializingFunction): The serialization function - deserializing_function(DeserializingFunction): The deserialization function - """ - - def __init__( - self, - serializing_function: Optional[SerializingFunction] = None, - deserializing_function: Optional[DeserializingFunction] = None, - ): - self.serializing_function = serializing_function - self.deserializing_function = deserializing_function - - def serialize(self, *args, **kwargs) -> bytes: - """ - Serialize the given data - Args: - *args: Variable length argument list - **kwargs: Arbitrary keyword arguments - Returns: - bytes: The serialized data - Exception: If the serialization fails - """ - # serialize the data - if self.serializing_function: - try: - return self.serializing_function(*args, **kwargs) - except Exception as e: - logger.exception( - "Serialization send error, please check the incoming serialization function" - ) - raise e - else: - # check if the data is bytes -> args[0] - if isinstance(args[0], bytes): - return args[0] - else: - err_msg = "The args[0] is not bytes, you should pass parameters of type bytes, or set the serialization function" - logger.error(err_msg) - raise ValueError(err_msg) - - def deserialize(self, data: bytes) -> Any: - """ - Deserialize the given data - Args: - data(bytes): The data to deserialize - Returns: - Any: The deserialized data - Exception: If the deserialization fails - """ - # deserialize the data - if not self.deserializing_function: - return data - else: - try: - return self.deserializing_function(data) - except Exception as e: - logger.exception( - "Deserialization send error, please check the incoming deserialization function" - ) - raise e diff --git a/dubbo/serialization/__init__.py b/dubbo/serialization/__init__.py new file mode 100644 index 0000000..ee2ef61 --- /dev/null +++ b/dubbo/serialization/__init__.py @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
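The old Serialization class above bundled optional serializing/deserializing callables into one object; the new package below splits them into Serializer/Deserializer implementations that normalise output to bytes. A minimal, self-contained sketch of that split, using json-based callables purely as an example (the class names here are illustrative, not the package's real CustomSerializer/CustomDeserializer):

    import json
    from typing import Any, Callable

    def to_bytes(obj: Any) -> bytes:
        # Normalise bytes-like output; anything else is rejected.
        if isinstance(obj, bytes):
            return obj
        if isinstance(obj, (bytearray, memoryview)):
            return bytes(obj)
        raise TypeError(f"expected bytes-like output, got {type(obj).__name__}")

    class FunctionSerializer:
        """Wraps a user-supplied serializing function and guarantees bytes output."""

        def __init__(self, fn: Callable[[Any], bytes]):
            self._fn = fn

        def serialize(self, obj: Any) -> bytes:
            return to_bytes(self._fn(obj))

    class FunctionDeserializer:
        """Wraps a user-supplied deserializing function."""

        def __init__(self, fn: Callable[[bytes], Any]):
            self._fn = fn

        def deserialize(self, data: bytes) -> Any:
            return self._fn(data)

    serializer = FunctionSerializer(lambda obj: json.dumps(obj).encode("utf-8"))
    deserializer = FunctionDeserializer(lambda data: json.loads(data.decode("utf-8")))
    payload = serializer.serialize({"message": "hello"})
    print(deserializer.deserialize(payload))  # {'message': 'hello'}
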
+ +from ._interfaces import Deserializer, SerializationError, Serializer, ensure_bytes +from .custom_serializers import CustomDeserializer, CustomSerializer +from .direct_serializers import DirectDeserializer, DirectSerializer + +__all__ = [ + "Serializer", + "Deserializer", + "SerializationError", + "ensure_bytes", + "DirectSerializer", + "DirectDeserializer", + "CustomSerializer", + "CustomDeserializer", +] diff --git a/dubbo/serialization/_interfaces.py b/dubbo/serialization/_interfaces.py new file mode 100644 index 0000000..65e808d --- /dev/null +++ b/dubbo/serialization/_interfaces.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any, Union + +__all__ = ["Serializer", "Deserializer", "SerializationError", "ensure_bytes"] + + +class SerializationError(Exception): + """ + Serialization error. + """ + + def __init__(self, message: str): + super().__init__(message) + self.message = message + + def __str__(self): + return self.message + + +def ensure_bytes(obj: Union[bytes, bytearray, memoryview]) -> bytes: + """ + Ensure that the input object is bytes or can be converted to bytes. + :param obj: The object to ensure. + :type obj: Union[bytes, bytearray, memoryview] + :return: The bytes object. + :rtype: bytes + """ + + if isinstance(obj, bytes): + return obj + elif isinstance(obj, (bytearray, memoryview)): + return bytes(obj) + else: + raise SerializationError( + f"SerializationError: The incoming object is of type '{type(obj).__name__}', " + f"which is not supported. Expected types are 'bytes', 'bytearray', or 'memoryview'.\n" + f"Current object type: '{type(obj).__name__}'.\n" + f"Please provide data of the correct type or configure the serializer to handle the current input type." + ) + + +class Serializer(abc.ABC): + """ + Interface for serializer. + """ + + @abc.abstractmethod + def serialize(self, obj: Any) -> bytes: + """ + Serialize an object to bytes. + :param obj: The object to serialize. + :type obj: Any + :return: The serialized bytes. + :rtype: bytes + :raises SerializationError: If serialization fails. + """ + raise NotImplementedError() + + +class Deserializer(abc.ABC): + """ + Interface for deserializer. + """ + + @abc.abstractmethod + def deserialize(self, data: bytes) -> Any: + """ + Deserialize bytes to an object. + :param data: The bytes to deserialize. + :type data: bytes + :return: The deserialized object. + :rtype: Any + :raises SerializationError: If deserialization fails. 
+ """ + raise NotImplementedError() diff --git a/dubbo/serialization/custom_serializers.py b/dubbo/serialization/custom_serializers.py new file mode 100644 index 0000000..c3ebceb --- /dev/null +++ b/dubbo/serialization/custom_serializers.py @@ -0,0 +1,85 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from dubbo.common.types import DeserializingFunction, SerializingFunction +from dubbo.serialization import ( + Deserializer, + SerializationError, + Serializer, + ensure_bytes, +) + +__all__ = ["CustomSerializer", "CustomDeserializer"] + + +class CustomSerializer(Serializer): + """ + Custom serializer that uses a custom serializing function to serialize objects. + """ + + __slots__ = ["serializer"] + + def __init__(self, serializer: SerializingFunction): + self.serializer = serializer + + def serialize(self, obj: Any) -> bytes: + """ + Serialize an object to bytes. + :param obj: The object to serialize. + :type obj: Any + :return: The serialized bytes. + :rtype: bytes + :raises SerializationError: If the object is not of type bytes, bytearray, or memoryview. + """ + try: + serialized_obj = self.serializer(obj) + except Exception as e: + raise SerializationError( + f"SerializationError: Failed to serialize object. Please check the serializer. \nDetails: {str(e)}", + ) + + return ensure_bytes(serialized_obj) + + +class CustomDeserializer(Deserializer): + """ + Custom deserializer that uses a custom deserializing function to deserialize objects. + """ + + __slots__ = ["deserializer"] + + def __init__(self, deserializer: DeserializingFunction): + self.deserializer = deserializer + + def deserialize(self, data: bytes) -> Any: + """ + Deserialize bytes to an object. + :param data: The bytes to deserialize. + :type data: bytes + :return: The deserialized object. + :rtype: Any + :raises SerializationError: If the object is not of type bytes, bytearray, or memoryview. + """ + try: + deserialized_obj = self.deserializer(data) + except Exception as e: + raise SerializationError( + f"SerializationError: Failed to deserialize object. Please check the deserializer. \nDetails: {str(e)}", + ) + + return deserialized_obj diff --git a/dubbo/serialization/direct_serializers.py b/dubbo/serialization/direct_serializers.py new file mode 100644 index 0000000..155a5a5 --- /dev/null +++ b/dubbo/serialization/direct_serializers.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from dubbo.common import SingletonBase +from dubbo.serialization import Deserializer, Serializer, ensure_bytes + +__all__ = ["DirectSerializer", "DirectDeserializer"] + + +class DirectSerializer(Serializer, SingletonBase): + """ + Direct serializer that does not perform any serialization. This serializer only checks if the given object is of + type bytes, bytearray, or memoryview and ensures it is returned as a bytes object. If the object is not of the + expected types, a SerializationError is raised. This serializer is a singleton. + """ + + def serialize(self, obj: Any) -> bytes: + """ + Serialize an object to bytes. + :param obj: The object to serialize. + :type obj: Any + :return: The serialized bytes. + :rtype: bytes + :raises SerializationError: If the object is not of type bytes, bytearray, or memoryview. + """ + return ensure_bytes(obj) if obj is not None else b"" + + +class DirectDeserializer(Deserializer): + """ + Direct deserializer that does not perform any serialization. This deserializer only returns the given bytes object + """ + + def deserialize(self, data: bytes) -> Any: + """ + Deserialize bytes to an object. + :param data: The bytes to deserialize. + :type data: bytes + :return: The deserialized object. + :rtype: Any + :raises SerializationError: If the object is not of type bytes, bytearray, or memoryview. + """ + return data diff --git a/dubbo/config/consumer_config.py b/dubbo/server.py similarity index 67% rename from dubbo/config/consumer_config.py rename to dubbo/server.py index 5037efe..3947913 100644 --- a/dubbo/config/consumer_config.py +++ b/dubbo/server.py @@ -14,17 +14,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dubbo.config.service_config import ServiceConfig +from dubbo.logger import loggerFactory -class ConsumerConfig: +_LOGGER = loggerFactory.get_logger(__name__) - def clone(self) -> "ConsumerConfig": + +class Server: + """ + Dubbo Server + """ + + __slots__ = ["_service"] + + def __init__(self, service_config: ServiceConfig): + self._service = service_config + + def start(self): """ - Clone the current configuration. - Returns: - ConsumerConfig: A new instance of ConsumerConfig. + Start the server """ - return ConsumerConfig() - - @classmethod - def default_config(cls): - return cls() + self._service.export() diff --git a/dubbo/url.py b/dubbo/url.py deleted file mode 100644 index 2178457..0000000 --- a/dubbo/url.py +++ /dev/null @@ -1,347 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import copy -from typing import Any, Dict, Optional -from urllib import parse - - -class URL: - """ - URL - Uniform Resource Locator. - Args: - scheme (str): The protocol of the URL. - host (str): The host of the URL. - port (int): The port number of the URL. - username (str): The username for URL authentication. - password (str): The password for URL authentication. - path (str): The path of the URL. - parameters (Dict[str, str]): The query parameters of the URL. - attributes (Dict[str, Any]): The attributes of the URL. (non-transferable) - - url example: - - http://www.facebook.com/friends?param1=value1¶m2=value2 - - http://username:password@10.20.130.230:8080/list?version=1.0.0 - - ftp://username:password@192.168.1.7:21/1/read.txt - - registry://192.168.1.7:9090/org.apache.dubbo.service1?param1=value1¶m2=value2 - """ - - __slots__ = [ - "_scheme", - "_host", - "_port", - "_location", - "_username", - "_password", - "_path", - "_parameters", - "_attributes", - ] - - def __init__( - self, - scheme: str, - host: str, - port: Optional[int] = None, - username: str = "", - password: str = "", - path: str = "", - parameters: Optional[Dict[str, str]] = None, - attributes: Optional[Dict[str, Any]] = None, - ): - self._scheme = scheme - self._host = host - self._port = port - # location -> host:port - self._location = f"{host}:{port}" if port else host - self._username = username - self._password = password - self._path = path - self._parameters = parameters or {} - self._attributes = attributes or {} - - @property - def scheme(self) -> str: - """ - Gets the protocol of the URL. - - Returns: - str: The protocol of the URL. - """ - return self._scheme - - @scheme.setter - def scheme(self, scheme: str) -> None: - """ - Sets the protocol of the URL. - - Args: - scheme (str): The protocol to set. - """ - self._scheme = scheme - - @property - def location(self) -> str: - """ - Gets the location (host:port) of the URL. - - Returns: - str: The location of the URL. - """ - return self._location - - @property - def host(self) -> str: - """ - Gets the host of the URL. - - Returns: - str: The host of the URL. - """ - return self._host - - @host.setter - def host(self, host: str) -> None: - """ - Sets the host of the URL. - - Args: - host (str): The host to set. - """ - self._host = host - self._location = f"{host}:{self.port}" if self.port else host - - @property - def port(self) -> Optional[int]: - """ - Gets the port of the URL. - - Returns: - int: The port of the URL. - """ - return self._port - - @port.setter - def port(self, port: int) -> None: - """ - Sets the port of the URL. - - Args: - port (int): The port to set. - """ - port = port if port > 0 else None - self._location = f"{self.host}:{port}" if port else self.host - - @property - def username(self) -> str: - """ - Gets the username for URL authentication. - - Returns: - str: The username for URL authentication. - """ - return self._username - - @username.setter - def username(self, username: str) -> None: - """ - Sets the username for URL authentication. - - Args: - username (str): The username to set. 
- """ - self._username = username - - @property - def password(self) -> str: - """ - Gets the password for URL authentication. - - Returns: - [str]: The password for URL authentication. - """ - return self._password - - @password.setter - def password(self, password: str) -> None: - """ - Sets the password for URL authentication. - - Args: - password (str): The password to set. - """ - self._password = password - - @property - def path(self) -> str: - """ - Gets the path of the URL. - - Returns: - str: The path of the URL. - """ - return self._path - - @path.setter - def path(self, path: str) -> None: - """ - Sets the path of the URL. - - Args: - path (str): The path to set. - """ - self._path = path - - def get_parameter(self, key: str) -> Optional[str]: - """ - Gets a query parameter from the URL. - - Args: - key (Optional[str]): The parameter name. - - Returns: - str or None: The parameter value. If the parameter does not exist, returns None. - """ - return self._parameters.get(key, None) - - def add_parameter(self, key: str, value: Any) -> None: - """ - Adds a query parameter to the URL. - - Args: - key (str): The parameter name. - value (Any): The parameter value. - """ - self._parameters[key] = str(value) if value is not None else "" - - @property - def attributes(self): - """ - Gets the attributes of the URL. - Returns: - Dict[str, Any]: The attributes of the URL. - """ - return self._attributes - - def build_string(self, encode: bool = False) -> str: - """ - Generates the URL string based on the current components. - - Args: - encode (bool): If True, the URL will be percent-encoded. - - Returns: - str: The generated URL string. - """ - # Set protocol - url = f"{self.scheme}://" if self.scheme else "" - # Set auth - if self.username: - url += f"{self.username}" - if self.password: - url += f":{self.password}" - url += "@" - # Set location - url += self.location if self.location else "" - # Set path - url += "/" - if self.path: - url += f"{self.path}" - # Set params - if self._parameters: - url += "?" + "&".join([f"{k}={v}" for k, v in self._parameters.items()]) - # If the URL needs to be encoded, encode it - if encode: - url = parse.quote(url) - return url - - def clone_without_attributes(self) -> "URL": - """ - Clones the URL object without the attributes. - Returns: - URL: The cloned URL object. - """ - return URL( - self.scheme, - self.host, - self.port, - self.username, - self.password, - self.path, - self._parameters.copy(), - ) - - def clone(self) -> "URL": - """ - Clones the URL object. Ignores the attributes. - - Returns: - URL: The cloned URL object. - """ - return URL( - self.scheme, - self.host, - self.port, - self.username, - self.password, - self.path, - self._parameters.copy(), - copy.deepcopy(self._attributes), - ) - - def __str__(self) -> str: - """ - Returns the URL string when the object is converted to a string. - - Returns: - str: The generated URL string. - """ - return self.build_string() - - @classmethod - def value_of(cls, url: str, encoded: bool = False) -> "URL": - """ - Creates a URL object from a URL string. - - Args: - url (str): The URL string to parse. format: [protocol://][username:password@][host:port]/[path] - encoded (bool): If True, the URL string is percent-encoded and will be decoded. - - Returns: - URL: The created URL object. 
- """ - if not url: - raise ValueError("URL string cannot be empty or None.") - - # If the URL is encoded, decode it - if encoded: - url = parse.unquote(url) - - if "://" not in url: - raise ValueError("Invalid URL format: missing protocol") - - parsed_url = parse.urlparse(url) - - protocol = parsed_url.scheme - host = parsed_url.hostname or "" - port = parsed_url.port or None - username = parsed_url.username or "" - password = parsed_url.password or "" - parameters = {k: v[0] for k, v in parse.parse_qs(parsed_url.query).items()} - path = parsed_url.path.lstrip("/") - - if not protocol: - raise ValueError("Invalid URL format: missing protocol.") - return URL(protocol, host, port, username, password, path, parameters) diff --git a/requirements.txt b/requirements.txt index 97fc58d..ca39f86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ h2~=4.1.0 -uvloop~=0.19.0 \ No newline at end of file +uvloop~=0.19.0 +kazoo~=2.10.0 \ No newline at end of file diff --git a/tests/common/tets_url.py b/tests/common/tets_url.py index 912c939..f4133e5 100644 --- a/tests/common/tets_url.py +++ b/tests/common/tets_url.py @@ -15,23 +15,23 @@ # limitations under the License. import unittest -from dubbo.url import URL +from dubbo.common.url import URL, create_url class TestUrl(unittest.TestCase): def test_str_to_url(self): - url_0 = URL.value_of( + url_0 = create_url( "http://www.facebook.com/friends?param1=value1¶m2=value2" ) self.assertEqual("http", url_0.scheme) self.assertEqual("www.facebook.com", url_0.host) self.assertEqual(None, url_0.port) self.assertEqual("friends", url_0.path) - self.assertEqual("value1", url_0.get_parameter("param1")) - self.assertEqual("value2", url_0.get_parameter("param2")) + self.assertEqual("value1", url_0.parameters["param1"]) + self.assertEqual("value2", url_0.parameters["param2"]) - url_1 = URL.value_of("ftp://username:password@192.168.1.7:21/1/read.txt") + url_1 = create_url("ftp://username:password@192.168.1.7:21/1/read.txt") self.assertEqual("ftp", url_1.scheme) self.assertEqual("username", url_1.username) self.assertEqual("password", url_1.password) @@ -40,11 +40,11 @@ def test_str_to_url(self): self.assertEqual("192.168.1.7:21", url_1.location) self.assertEqual("1/read.txt", url_1.path) - url_2 = URL.value_of("file:///home/user1/router.js?type=script") + url_2 = create_url("file:///home/user1/router.js?type=script") self.assertEqual("file", url_2.scheme) self.assertEqual("home/user1/router.js", url_2.path) - url_3 = URL.value_of( + url_3 = create_url( "http%3A//www.facebook.com/friends%3Fparam1%3Dvalue1%26param2%3Dvalue2", encoded=True, ) @@ -52,8 +52,8 @@ def test_str_to_url(self): self.assertEqual("www.facebook.com", url_3.host) self.assertEqual(None, url_3.port) self.assertEqual("friends", url_3.path) - self.assertEqual("value1", url_3.get_parameter("param1")) - self.assertEqual("value2", url_3.get_parameter("param2")) + self.assertEqual("value1", url_3.parameters["param1"]) + self.assertEqual("value2", url_3.parameters["param2"]) def test_url_to_str(self): url_0 = URL( @@ -66,7 +66,7 @@ def test_url_to_str(self): parameters={"type": "a"}, ) self.assertEqual( - "tri://username:password@127.0.0.1:12/path?type=a", url_0.build_string() + "tri://username:password@127.0.0.1:12/path?type=a", url_0.to_str() ) url_1 = URL( @@ -76,7 +76,7 @@ def test_url_to_str(self): path="path", parameters={"type": "a"}, ) - self.assertEqual("tri://127.0.0.1:12/path?type=a", url_1.build_string()) + self.assertEqual("tri://127.0.0.1:12/path?type=a", url_1.to_str()) url_2 = 
URL(scheme="tri", host="127.0.0.1", port=12, parameters={"type": "a"}) - self.assertEqual("tri://127.0.0.1:12/?type=a", url_2.build_string()) + self.assertEqual("tri://127.0.0.1:12/?type=a", url_2.to_str()) diff --git a/tests/logger/__init__.py b/tests/logger/__init__.py deleted file mode 100644 index bcba37a..0000000 --- a/tests/logger/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/logger/test_logger_factory.py b/tests/logger/test_logger_factory.py deleted file mode 100644 index c3e6fd1..0000000 --- a/tests/logger/test_logger_factory.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import unittest - -from dubbo.constants import logger_constants as logger_constants -from dubbo.constants.logger_constants import Level -from dubbo.config import LoggerConfig -from dubbo.logger.logger_factory import loggerFactory -from dubbo.logger.logging.logger_adapter import LoggingLoggerAdapter - - -class TestLoggerFactory(unittest.TestCase): - - def test_without_config(self): - # Test the case where config is not used - logger = loggerFactory.get_logger("test_factory") - logger.info("info log -> without_config ") - - def test_with_config(self): - # Test the case where config is used - config = LoggerConfig.default_config() - config.init() - logger = loggerFactory.get_logger("test_factory") - logger.info("info log -> with_config ") - - url = config.get_url() - url.add_parameter(logger_constants.FILE_ENABLED_KEY, True) - loggerFactory.set_logger_adapter(LoggingLoggerAdapter(url)) - loggerFactory.set_level(Level.DEBUG) - logger = loggerFactory.get_logger("test_factory") - logger.debug("debug log -> with_config -> open file") - - url.add_parameter(logger_constants.CONSOLE_ENABLED_KEY, False) - loggerFactory.set_logger_adapter(LoggingLoggerAdapter(url)) - loggerFactory.set_level(Level.DEBUG) - logger.debug("debug log -> with_config -> lose console") diff --git a/tests/logger/test_logging_logger.py b/tests/logger/test_logging_logger.py deleted file mode 100644 index 9915dc0..0000000 --- a/tests/logger/test_logging_logger.py +++ /dev/null @@ -1,50 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
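The deleted logger tests above toggled console and file output and adjusted levels through loggerFactory and LoggerConfig. For comparison, a minimal standard-library sketch of the same idea (not the project's logger API; the log file name is illustrative):

    import logging

    logger = logging.getLogger("demo")
    logger.setLevel(logging.DEBUG)

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logger.addHandler(console)

    file_handler = logging.FileHandler("dubbo-demo.log")  # illustrative file name
    file_handler.setLevel(logging.DEBUG)
    logger.addHandler(file_handler)

    logger.debug("file only: the console handler filters below INFO")
    logger.info("written to both handlers")

    logger.removeHandler(console)  # roughly what disabling console output amounts to
    logger.warning("file only once the console handler is removed")
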
-import unittest - -from dubbo.constants.logger_constants import Level -from dubbo.config import LoggerConfig -from dubbo.logger.logging.logger_adapter import LoggingLoggerAdapter - - -class TestInternalLogger(unittest.TestCase): - - def test_log(self): - logger_adapter = LoggingLoggerAdapter( - config=LoggerConfig.default_config().get_url() - ) - logger = logger_adapter.get_logger("test") - logger.log(Level.INFO, "test log") - logger.debug("test debug") - logger.info("test info") - logger.warning("test warning") - logger.error("test error") - logger.critical("test critical") - logger.fatal("test fatal") - try: - 1 / 0 - except ZeroDivisionError: - logger.exception("test exception") - - # test different default logger level - logger_adapter.level = Level.INFO - logger.debug("debug can't be logged") - - logger_adapter.level = Level.WARNING - logger.info("info can't be logged") - - logger_adapter.level = Level.ERROR - logger.warning("warning can't be logged") From d17a8ff075a437ba4a41248a5e26ab1f1bcfa0fd Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 4 Aug 2024 14:42:09 +0800 Subject: [PATCH 30/32] docs: Comment completely using reStructuredText style --- dubbo/config/__init__.py | 1 - dubbo/config/application_config.py | 45 ------------- dubbo/config/logger_config.py | 54 ++++++++-------- dubbo/protocol/_interfaces.py | 24 ++++--- dubbo/protocol/invocation.py | 19 ++++-- dubbo/protocol/triple/constants.py | 6 +- dubbo/protocol/triple/protocol.py | 4 +- dubbo/remoting/aio/aio_transporter.py | 7 ++- dubbo/remoting/aio/event_loop.py | 16 ++--- dubbo/remoting/aio/http2/frames.py | 61 ++++++++++++------ dubbo/remoting/aio/http2/headers.py | 8 +-- dubbo/remoting/aio/http2/protocol.py | 57 ++++++++++------- dubbo/remoting/aio/http2/registries.py | 73 +++++++++++----------- dubbo/remoting/aio/http2/stream_handler.py | 29 ++++----- dubbo/remoting/aio/http2/utils.py | 8 +-- 15 files changed, 213 insertions(+), 199 deletions(-) delete mode 100644 dubbo/config/application_config.py diff --git a/dubbo/config/__init__.py b/dubbo/config/__init__.py index 63c4ec1..7ffd615 100644 --- a/dubbo/config/__init__.py +++ b/dubbo/config/__init__.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .application_config import ApplicationConfig from .logger_config import FileLoggerConfig, LoggerConfig from .protocol_config import ProtocolConfig from .reference_config import ReferenceConfig diff --git a/dubbo/config/application_config.py b/dubbo/config/application_config.py deleted file mode 100644 index 8ee0806..0000000 --- a/dubbo/config/application_config.py +++ /dev/null @@ -1,45 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class ApplicationConfig: - """ - Application configuration. 
- Attributes: - _name(str): Application name - _version(str): Application version - _owner(str): Application owner - _organization(str): Application organization (BU) - _environment(str): Application environment, e.g. dev, test or production - """ - - _name: str - _version: str - _owner: str - _organization: str - _environment: str - - def clone(self) -> "ApplicationConfig": - """ - Clone the current configuration. - Returns: - ApplicationConfig: A new instance of ApplicationConfig. - """ - return ApplicationConfig() - - @classmethod - def default_config(cls): - return cls() diff --git a/dubbo/config/logger_config.py b/dubbo/config/logger_config.py index f34ce13..ecae584 100644 --- a/dubbo/config/logger_config.py +++ b/dubbo/config/logger_config.py @@ -29,15 +29,20 @@ class FileLoggerConfig: """ File logger configuration. - Attributes: - rotate(FileRotateType): File rotate type. Optional: NONE,SIZE,TIME. Default: NONE. - file_formatter(Optional[str]): file format, if null, use global format. - file_dir(str): file directory. Default: user home dir - file_name(str): file name. Default: dubbo.log - backup_count(int): backup count. Default: 10 (when rotate is not NONE, backup_count is required) - max_bytes(int): maximum file size. Default: 1024.(when rotate is SIZE, max_bytes is required) - interval(int): interval time in seconds. Default: 1.(when rotate is TIME, interval is required, unit is day) - + :param rotate: File rotate type. + :type rotate: logger_constants.FileRotateType + :param file_formatter: File formatter. + :type file_formatter: Optional[str] + :param file_dir: File directory. + :type file_dir: str + :param file_name: File name. + :type file_name: str + :param backup_count: Backup count. + :type backup_count: int + :param max_bytes: Max bytes. + :type max_bytes: int + :param interval: Interval. + :type interval: int """ rotate: logger_constants.FileRotateType = logger_constants.FileRotateType.NONE @@ -68,24 +73,8 @@ def dict(self) -> Dict[str, str]: class LoggerConfig: """ Logger configuration. - - Attributes: - _driver(str): logger driver type. - _level(Level): logger level. - _console_enabled(bool): logger console enabled. - _file_enabled(bool): logger file enabled. - _file_config(FileLoggerConfig): logger file config. """ - # global - _driver: str - _level: Level - # console - _console_enabled: bool - # file - _file_enabled: bool - _file_config: FileLoggerConfig - __slots__ = [ "_driver", "_level", @@ -98,11 +87,24 @@ class LoggerConfig: def __init__( self, driver, - level, + level: Level, console_enabled: bool, file_enabled: bool, file_config: FileLoggerConfig, ): + """ + Initialize the logger configuration. + :param driver: The logger driver. + :type driver: str + :param level: The logger level. + :type level: Level + :param console_enabled: Whether to enable console logger. + :type console_enabled: bool + :param file_enabled: Whether to enable file logger. + :type file_enabled: bool + :param file_config: The file logger configuration. 
+ :type file_config: FileLogger + """ # set global config self._driver = driver self._level = level diff --git a/dubbo/protocol/_interfaces.py b/dubbo/protocol/_interfaces.py index b3ba210..68f8f55 100644 --- a/dubbo/protocol/_interfaces.py +++ b/dubbo/protocol/_interfaces.py @@ -56,8 +56,8 @@ class Result(abc.ABC): def set_value(self, value: Any) -> None: """ Set the value of the result - Args: - value: Value to set + :param value: The value to set + :type value: Any """ raise NotImplementedError() @@ -72,8 +72,8 @@ def value(self) -> Any: def set_exception(self, exception: Exception) -> None: """ Set the exception to the result - Args: - exception: Exception to set + :param exception: The exception to set + :type exception: Exception """ raise NotImplementedError() @@ -94,8 +94,10 @@ class Invoker(Node, abc.ABC): def invoke(self, invocation: Invocation) -> Result: """ Invoke the service. - Returns: - Result: The result of the invocation. + :param invocation: The invocation. + :type invocation: Invocation + :return: The result. + :rtype: Result """ raise NotImplementedError() @@ -106,6 +108,8 @@ class Protocol(abc.ABC): def export(self, url: URL): """ Export a remote service. + :param url: The URL. + :type url: URL """ raise NotImplementedError() @@ -113,9 +117,9 @@ def export(self, url: URL): def refer(self, url: URL) -> Invoker: """ Refer a remote service. - Args: - url (URL): The URL of the remote service. - Returns: - Invoker: The invoker of the remote service. + :param url: The URL. + :type url: URL + :return: The invoker. + :rtype: Invoker """ raise NotImplementedError() diff --git a/dubbo/protocol/invocation.py b/dubbo/protocol/invocation.py index a3ac662..8e29800 100644 --- a/dubbo/protocol/invocation.py +++ b/dubbo/protocol/invocation.py @@ -22,12 +22,6 @@ class RpcInvocation(Invocation): """ The RpcInvocation class is an implementation of the Invocation interface. - Args: - service_name (str): The name of the service. - method_name (str): The name of the method. - argument (Any): The method argument. - attachments (Optional[Dict[str, str]]): Passed to the remote server during RPC call - attributes (Optional[Dict[str, Any]]): Only used on the caller side, will not appear on the wire. """ __slots__ = [ @@ -46,6 +40,19 @@ def __init__( attachments: Optional[Dict[str, str]] = None, attributes: Optional[Dict[str, Any]] = None, ): + """ + Initialize a new RpcInvocation instance. + :param service_name: The service name. + :type service_name: str + :param method_name: The method name. + :type method_name: str + :param argument: The argument. + :type argument: Any + :param attachments: The attachments. + :type attachments: Optional[Dict[str, str]] + :param attributes: The attributes. + :type attributes: Optional[Dict[str, Any]] + """ self._service_name = service_name self._method_name = method_name self._argument = argument diff --git a/dubbo/protocol/triple/constants.py b/dubbo/protocol/triple/constants.py index a51244e..98d71ad 100644 --- a/dubbo/protocol/triple/constants.py +++ b/dubbo/protocol/triple/constants.py @@ -78,8 +78,10 @@ class GRpcCode(enum.Enum): def from_code(cls, code: int) -> "GRpcCode": """ Get the RPC status code from the given code. - Args: - code: The RPC status code. + :param code: The RPC status code. + :type code: int + :return: The RPC status code. 
+ :rtype: GRpcCode """ for rpc_code in cls: if rpc_code.value == code: diff --git a/dubbo/protocol/triple/protocol.py b/dubbo/protocol/triple/protocol.py index 9347fc8..c0dd386 100644 --- a/dubbo/protocol/triple/protocol.py +++ b/dubbo/protocol/triple/protocol.py @@ -89,8 +89,8 @@ def listener_factory(_path_resolver): def refer(self, url: URL) -> Invoker: """ Refer a remote service. - Args: - url (URL): The URL of the remote service. + :param url: The URL. + :type url: URL """ executor = ThreadPoolExecutor(thread_name_prefix="dubbo-tri-") # Create a stream handler diff --git a/dubbo/remoting/aio/aio_transporter.py b/dubbo/remoting/aio/aio_transporter.py index e721195..dd39803 100644 --- a/dubbo/remoting/aio/aio_transporter.py +++ b/dubbo/remoting/aio/aio_transporter.py @@ -33,8 +33,6 @@ class AioClient(Client): """ Asyncio client. - Args: - url(URL): The configuration of the client. """ __slots__ = [ @@ -47,6 +45,11 @@ class AioClient(Client): ] def __init__(self, url: URL): + """ + Initialize the client. + :param url: The URL. + :type url: URL + """ super().__init__(url) # Set the side of the transporter to client. diff --git a/dubbo/remoting/aio/event_loop.py b/dubbo/remoting/aio/event_loop.py index 5f0df4e..753be96 100644 --- a/dubbo/remoting/aio/event_loop.py +++ b/dubbo/remoting/aio/event_loop.py @@ -80,8 +80,8 @@ def __init__(self, in_other_tread: bool = True): def loop(self): """ Get the event loop. - Returns: - The event loop. + :return: The event loop. + :rtype: asyncio.AbstractEventLoop """ return self._loop @@ -89,26 +89,28 @@ def loop(self): def thread(self) -> Optional[threading.Thread]: """ Get the thread of the event loop. - Returns: - The thread of the event loop. If not yet started, this is None. + :return: The thread of the event loop. If not yet started, this is None. + :rtype: Optional[threading.Thread] """ return self._thread def check_thread(self) -> bool: """ Check if the current thread is the event loop thread. - Returns: - If the current thread is the event loop thread, return True. Otherwise, return False. + :return: True if the current thread is the event loop thread, otherwise False. + :rtype: bool """ return threading.current_thread().ident == self._thread.ident def is_started(self) -> bool: """ Check if the event loop is started. + :return: True if the event loop is started, otherwise False. + :rtype: bool """ return self._started - def start(self): + def start(self) -> None: """ Start the asyncio event loop. """ diff --git a/dubbo/remoting/aio/http2/frames.py b/dubbo/remoting/aio/http2/frames.py index 2733b8d..8967bd7 100644 --- a/dubbo/remoting/aio/http2/frames.py +++ b/dubbo/remoting/aio/http2/frames.py @@ -32,9 +32,6 @@ class Http2Frame: """ HTTP/2 frame class. It is used to represent an HTTP/2 frame. - Args: - stream_id: The stream identifier. - frame_type: The frame type. """ __slots__ = ["stream_id", "frame_type", "end_stream", "timestamp"] @@ -45,6 +42,15 @@ def __init__( frame_type: Http2FrameType, end_stream: bool = False, ): + """ + Initialize the HTTP/2 frame. + :param stream_id: The stream identifier. + :type stream_id: int + :param frame_type: The frame type. + :type frame_type: Http2FrameType + :param end_stream: Whether the stream is ended. + :type end_stream: bool + """ self.stream_id = stream_id self.frame_type = frame_type self.end_stream = end_stream @@ -56,10 +62,6 @@ def __repr__(self) -> str: class HeadersFrame(Http2Frame): """ HTTP/2 headers frame. - Args: - stream_id: The stream identifier. - headers: The HTTP/2 headers. 
- end_stream: Whether the stream is ended. """ __slots__ = ["headers"] @@ -70,6 +72,15 @@ def __init__( headers: Http2Headers, end_stream: bool = False, ): + """ + Initialize the HTTP/2 headers frame. + :param stream_id: The stream identifier. + :type stream_id: int + :param headers: The headers to send. + :type headers: Http2Headers + :param end_stream: Whether the stream is ended. + :type end_stream: bool + """ super().__init__(stream_id, Http2FrameType.HEADERS, end_stream) self.headers = headers @@ -80,11 +91,6 @@ def __repr__(self) -> str: class DataFrame(Http2Frame): """ HTTP/2 data frame. - Args: - stream_id: The stream identifier. - data: The data to send. - length: The amount of data received that counts against the flow control window. - end_stream: Whether the stream """ __slots__ = ["data", "padding"] @@ -96,6 +102,16 @@ def __init__( length: int, end_stream: bool = False, ): + """ + Initialize the HTTP/2 data frame. + :param stream_id: The stream identifier. + :type stream_id: int + :param data: The data to send. + :type data: bytes + :param length: The length of the data. + :type length: int + :param end_stream: Whether the stream is ended. + """ super().__init__(stream_id, Http2FrameType.DATA, end_stream) self.data = data self.padding = length @@ -107,9 +123,6 @@ def __repr__(self) -> str: class WindowUpdateFrame(Http2Frame): """ HTTP/2 window update frame. - Args: - stream_id: The stream identifier. - delta: The number of bytes by which to increase the flow control window. """ __slots__ = ["delta"] @@ -119,6 +132,13 @@ def __init__( stream_id: int, delta: int, ): + """ + Initialize the HTTP/2 window update frame. + :param stream_id: The stream identifier. + :type stream_id: int + :param delta: The delta value. + :type delta: int + """ super().__init__(stream_id, Http2FrameType.WINDOW_UPDATE, False) self.delta = delta @@ -129,9 +149,6 @@ def __repr__(self) -> str: class ResetStreamFrame(Http2Frame): """ HTTP/2 reset stream frame. - Args: - stream_id: The stream identifier. - error_code: The error code that indicates the reason for closing the stream. """ __slots__ = ["error_code"] @@ -141,6 +158,13 @@ def __init__( stream_id: int, error_code: Http2ErrorCode, ): + """ + Initialize the HTTP/2 reset stream frame. + :param stream_id: The stream identifier. + :type stream_id: int + :param error_code: The error code. + :type error_code: Http2ErrorCode + """ super().__init__(stream_id, Http2FrameType.RST_STREAM, True) self.error_code = error_code @@ -148,4 +172,5 @@ def __repr__(self) -> str: return f"" +# User action frames. UserActionFrames = Union[HeadersFrame, DataFrame, ResetStreamFrame] diff --git a/dubbo/remoting/aio/http2/headers.py b/dubbo/remoting/aio/http2/headers.py index f50e314..47311be 100644 --- a/dubbo/remoting/aio/http2/headers.py +++ b/dubbo/remoting/aio/http2/headers.py @@ -157,10 +157,10 @@ def __repr__(self) -> str: def from_list(cls, headers: List[Tuple[str, str]]) -> "Http2Headers": """ Create an Http2Headers object from a list. - Args: - headers: The headers list. - Returns: - The Http2Headers object. + :param headers: The headers list. + :type headers: List[Tuple[str, str]] + :return: The Http2Headers object. 
+ :rtype: Http2Headers """ http2_headers = cls() http2_headers._headers = dict(headers) diff --git a/dubbo/remoting/aio/http2/protocol.py b/dubbo/remoting/aio/http2/protocol.py index 7276412..09e5661 100644 --- a/dubbo/remoting/aio/http2/protocol.py +++ b/dubbo/remoting/aio/http2/protocol.py @@ -93,8 +93,7 @@ def connection_made(self, transport: asyncio.Transport): def get_next_stream_id(self, future) -> None: """ Create a new stream.(thread-safe) - Args: - future: The future to set the stream identifier. + :param future: The future to set the stream identifier. """ def _inner_operation(_future): @@ -108,13 +107,15 @@ def send_frame( frame: UserActionFrames, stream: Http2Stream, event: Optional[asyncio.Event] = None, - ): + ) -> None: """ Send the HTTP/2 frame.(thread-unsafe) - Args: - frame: The HTTP/2 frame. - stream: The HTTP/2 stream. - event: The event to be set after sending the frame. + :param frame: The frame to send. + :type frame: UserActionFrames + :param stream: The stream. + :type stream: Http2Stream + :param event: The event to be set after sending the frame. + :type event: Optional[asyncio.Event] """ frame_type = frame.frame_type if frame_type == Http2FrameType.HEADERS: @@ -134,14 +135,16 @@ def _send_headers_frame( headers: List[Tuple[str, str]], end_stream: bool, event: Optional[asyncio.Event] = None, - ): + ) -> None: """ Send the HTTP/2 headers frame.(thread-unsafe) - Args: - stream_id: The stream identifier. - headers: The headers to send. - end_stream: Whether the stream is ended. - event: The event to be set after sending the frame. + :param stream_id: The stream identifier. + :type stream_id: int + :param headers: The headers. + :type headers: List[Tuple[str, str]] + :param end_stream: Whether the stream is ended. + :type end_stream: bool + :param event: The event to be set after sending the frame. """ self._h2_connection.send_headers(stream_id, headers, end_stream=end_stream) self._transport.write(self._h2_connection.data_to_send()) @@ -149,19 +152,26 @@ def _send_headers_frame( def _send_reset_frame( self, stream_id: int, error_code: int, event: Optional[asyncio.Event] = None - ): + ) -> None: """ Send the HTTP/2 reset frame.(thread-unsafe) - Args: - stream_id: The stream identifier. - error_code: The error code. - event: The event to be set after sending the frame + :param stream_id: The stream identifier. + :type stream_id: int + :param error_code: The error code. + :type error_code: int + :param event: The event to be set after sending the frame. + :type event: Optional[asyncio.Event] """ self._h2_connection.reset_stream(stream_id, error_code) self._transport.write(self._h2_connection.data_to_send()) EventHelper.set(event) def data_received(self, data): + """ + Called when some data is received from the transport. + :param data: The data received. + :type data: bytes + """ events = self._h2_connection.receive_data(data) # Process the event try: @@ -185,15 +195,16 @@ def data_received(self, data): except Exception as e: raise ProtocolError("Failed to process the Http/2 event.") from e - def ack_received_data(self, stream_id: int, padding: int): + def ack_received_data(self, stream_id: int, ack_length: int) -> None: """ Acknowledge the received data. - Args: - stream_id: The stream identifier. - padding: The amount of data received that counts against the flow control window. + :param stream_id: The stream identifier. + :type stream_id: int + :param ack_length: The length of the data to acknowledge. 
+ :type ack_length: int """ - self._h2_connection.acknowledge_received_data(padding, stream_id) + self._h2_connection.acknowledge_received_data(ack_length, stream_id) self._transport.write(self._h2_connection.data_to_send()) def close(self): diff --git a/dubbo/remoting/aio/http2/registries.py b/dubbo/remoting/aio/http2/registries.py index fd07bf2..10e636d 100644 --- a/dubbo/remoting/aio/http2/registries.py +++ b/dubbo/remoting/aio/http2/registries.py @@ -15,7 +15,7 @@ # limitations under the License. import enum -from typing import Optional +from typing import Optional, Union __all__ = ["Http2FrameType", "Http2ErrorCode", "Http2Settings", "HttpStatus"] @@ -110,10 +110,8 @@ class Http2ErrorCode(enum.Enum): def get(cls, code: int): """ Get the error code by code. - Args: - code: The error code. - Returns: - The error code. + :param code: The error code. + :type code: int """ for error_code in cls: if error_code.value == code: @@ -237,56 +235,61 @@ def from_code(cls, code: int) -> "HttpStatus": return status @staticmethod - def is_1xx(status): + def is_1xx(status: Union["HttpStatus", int]) -> bool: """ Check if the given status is an informational (1xx) status code. - Args: - status: HttpStatus to check - Returns: - True if the status code is in the 1xx range, False otherwise + :param status: HttpStatus to check + :type status: Union[HttpStatus, int] + :return: True if the status code is in the 1xx range, False otherwise + :rtype: bool """ - return 100 <= status.value < 200 + value = status if isinstance(status, int) else status.value + return 100 <= value < 200 @staticmethod - def is_2xx(status): + def is_2xx(status: Union["HttpStatus", int]) -> bool: """ Check if the given status is a successful (2xx) status code. - Args: - status: HttpStatus to check - Returns: - True if the status code is in the 2xx range, False otherwise + :param status: HttpStatus to check + :type status: Union[HttpStatus, int] + :return: True if the status code is in the 2xx range, False otherwise + :rtype: bool """ - return 200 <= status.value < 300 + value = status if isinstance(status, int) else status.value + return 200 <= value < 300 @staticmethod - def is_3xx(status): + def is_3xx(status: Union["HttpStatus", int]) -> bool: """ Check if the given status is a redirection (3xx) status code. - Args: - status: HttpStatus to check - Returns: - True if the status code is in the 3xx range, False otherwise + :param status: HttpStatus to check + :type status: Union[HttpStatus, int] + :return: True if the status code is in the 3xx range, False otherwise + :rtype: bool """ - return 300 <= status.value < 400 + value = status if isinstance(status, int) else status.value + return 300 <= value < 400 @staticmethod - def is_4xx(status): + def is_4xx(status: Union["HttpStatus", int]) -> bool: """ Check if the given status is a client error (4xx) status code. - Args: - status: HttpStatus to check - Returns: - True if the status code is in the 4xx range, False otherwise + :param status: HttpStatus to check + :type status: Union[HttpStatus, int] + :return: True if the status code is in the 4xx range, False otherwise + :rtype: bool """ - return 400 <= status.value < 500 + value = status if isinstance(status, int) else status.value + return 400 <= value < 500 @staticmethod - def is_5xx(status): + def is_5xx(status: Union["HttpStatus", int]) -> bool: """ Check if the given status is a server error (5xx) status code. 
- Args: - status: HttpStatus to check - Returns: - True if the status code is in the 5xx range, False otherwise + :param status: HttpStatus to check + :type status: Union[HttpStatus, int] + :return: True if the status code is in the 5xx range, False otherwise + :rtype: bool """ - return 500 <= status.value < 600 + value = status if isinstance(status, int) else status.value + return 500 <= value < 600 diff --git a/dubbo/remoting/aio/http2/stream_handler.py b/dubbo/remoting/aio/http2/stream_handler.py index dfea951..49e127b 100644 --- a/dubbo/remoting/aio/http2/stream_handler.py +++ b/dubbo/remoting/aio/http2/stream_handler.py @@ -56,9 +56,10 @@ def __init__(self, executor: Optional[futures.ThreadPoolExecutor] = None): def do_init(self, loop: asyncio.AbstractEventLoop, protocol) -> None: """ Initialize the StreamMultiplexHandler.\ - Args: - loop: The asyncio event loop. - protocol: The HTTP/2 protocol. + :param loop: The event loop. + :type loop: asyncio.AbstractEventLoop + :param protocol: The HTTP/2 protocol. + :type protocol: Http2Protocol """ self._loop = loop self._protocol = protocol @@ -67,35 +68,35 @@ def do_init(self, loop: asyncio.AbstractEventLoop, protocol) -> None: def put_stream(self, stream_id: int, stream: DefaultHttp2Stream) -> None: """ Put the stream into the stream map. - Args: - stream_id: The stream identifier. - stream: The stream. + :param stream_id: The stream identifier. + :type stream_id: int + :param stream: The stream. + :type stream: DefaultHttp2Stream """ self._streams[stream_id] = stream def get_stream(self, stream_id: int) -> Optional[DefaultHttp2Stream]: """ Get the stream by stream identifier. - Args: - stream_id: The stream identifier. - Returns: - The stream. + :param stream_id: The stream identifier. + :type stream_id: int + :return: The stream. """ return self._streams.get(stream_id) def remove_stream(self, stream_id: int) -> None: """ Remove the stream by stream identifier. - Args: - stream_id: The stream identifier. + :param stream_id: The stream identifier. + :type stream_id: int """ self._streams.pop(stream_id, None) def handle_frame(self, frame: UserActionFrames) -> None: """ Handle the HTTP/2 frame. - Args: - frame: The HTTP/2 frame. + :param frame: The HTTP/2 frame. + :type frame: UserActionFrames """ stream = self._streams.get(frame.stream_id) if stream: diff --git a/dubbo/remoting/aio/http2/utils.py b/dubbo/remoting/aio/http2/utils.py index 4de376e..64f729d 100644 --- a/dubbo/remoting/aio/http2/utils.py +++ b/dubbo/remoting/aio/http2/utils.py @@ -41,10 +41,10 @@ def convert_to_frame( ) -> Union[HeadersFrame, DataFrame, ResetStreamFrame, WindowUpdateFrame, None]: """ Convert a h2.events.Event to HTTP/2 Frame. - Args: - event: The H2 event to convert. - Returns: - The converted HTTP/2 Frame. If the event is not supported, return None. + :param event: The H2 event. + :type event: h2.events.Event + :return: The HTTP/2 frame. 
+ :rtype: Union[HeadersFrame, DataFrame, ResetStreamFrame, WindowUpdateFrame, None] """ if isinstance( event, From 16661810c7670c9626a84e3bb9bd55a00a110bc0 Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 4 Aug 2024 16:49:56 +0800 Subject: [PATCH 31/32] fix: update something --- dubbo/proxy/__init__.py | 4 ++-- dubbo/proxy/_interfaces.py | 30 +----------------------------- dubbo/proxy/callables.py | 14 +------------- 3 files changed, 4 insertions(+), 44 deletions(-) diff --git a/dubbo/proxy/__init__.py b/dubbo/proxy/__init__.py index 4c4ddd8..6080326 100644 --- a/dubbo/proxy/__init__.py +++ b/dubbo/proxy/__init__.py @@ -14,6 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._interfaces import RpcCallable, RpcCallableFactory +from ._interfaces import RpcCallable -__all__ = ["RpcCallable", "RpcCallableFactory"] +__all__ = ["RpcCallable"] diff --git a/dubbo/proxy/_interfaces.py b/dubbo/proxy/_interfaces.py index d6c9c98..0863a44 100644 --- a/dubbo/proxy/_interfaces.py +++ b/dubbo/proxy/_interfaces.py @@ -20,10 +20,7 @@ from dubbo.protocol import Invoker from dubbo.proxy.handlers import RpcServiceHandler -__all__ = [ - "RpcCallable", - "RpcCallableFactory", -] +__all__ = ["RpcCallable"] class RpcCallable(abc.ABC): @@ -34,28 +31,3 @@ def __call__(self, *args, **kwargs): call the rpc service """ raise NotImplementedError() - - -class RpcCallableFactory(abc.ABC): - - @abc.abstractmethod - def get_callable(self, invoker: Invoker, url: URL) -> RpcCallable: - """ - get the rpc proxy - :param invoker: the invoker. - :type invoker: Invoker - :param url: the url. - :type url: URL - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_invoker(self, service_handler: RpcServiceHandler, url: URL) -> Invoker: - """ - get the rpc invoker - :param service_handler: the service handler. - :type service_handler: RpcServiceHandler - :param url: the url. - :type url: URL - """ - raise NotImplementedError() diff --git a/dubbo/proxy/callables.py b/dubbo/proxy/callables.py index 5f17098..f232489 100644 --- a/dubbo/proxy/callables.py +++ b/dubbo/proxy/callables.py @@ -20,7 +20,7 @@ from dubbo.common.url import URL from dubbo.protocol import Invoker from dubbo.protocol.invocation import RpcInvocation -from dubbo.proxy import RpcCallable, RpcCallableFactory +from dubbo.proxy import RpcCallable __all__ = ["MultipleRpcCallable"] @@ -60,15 +60,3 @@ def __call__(self, argument: Any) -> Any: # Do invoke. result = self._invoker.invoke(invocation) return result.value() - - -class DefaultRpcCallableFactory(RpcCallableFactory): - """ - The RpcCallableFactory class. 
- """ - - def get_callable(self, invoker: Invoker, url: URL) -> RpcCallable: - return MultipleRpcCallable(invoker, url) - - def get_invoker(self, service_handler: RpcServiceHandler, url: URL) -> Invoker: - pass From 51280afe3edbf5f6fa6c6eb713445274f8f886b4 Mon Sep 17 00:00:00 2001 From: zaki Date: Sun, 4 Aug 2024 16:52:08 +0800 Subject: [PATCH 32/32] fix: fix ci --- dubbo/proxy/_interfaces.py | 4 ---- dubbo/proxy/callables.py | 2 -- 2 files changed, 6 deletions(-) diff --git a/dubbo/proxy/_interfaces.py b/dubbo/proxy/_interfaces.py index 0863a44..fb04482 100644 --- a/dubbo/proxy/_interfaces.py +++ b/dubbo/proxy/_interfaces.py @@ -16,10 +16,6 @@ import abc -from dubbo.common import URL -from dubbo.protocol import Invoker -from dubbo.proxy.handlers import RpcServiceHandler - __all__ = ["RpcCallable"] diff --git a/dubbo/proxy/callables.py b/dubbo/proxy/callables.py index f232489..22dd793 100644 --- a/dubbo/proxy/callables.py +++ b/dubbo/proxy/callables.py @@ -24,8 +24,6 @@ __all__ = ["MultipleRpcCallable"] -from dubbo.proxy.handlers import RpcServiceHandler - class MultipleRpcCallable(RpcCallable): """