diff --git a/auto_dev/data/templates/customs/handler_header.jinja b/auto_dev/data/templates/customs/handler_header.jinja index b1faf528..4871c393 100644 --- a/auto_dev/data/templates/customs/handler_header.jinja +++ b/auto_dev/data/templates/customs/handler_header.jinja @@ -48,15 +48,6 @@ from .exceptions import ( {% endfor %} -@dataclass -class ApiResponse: - """Api response.""" - headers: dict[str, str] - content: bytes - status_code: int - status_text: str - - class ApiHttpHandler(Handler): """Implements the API HTTP handler.""" diff --git a/auto_dev/data/templates/customs/method_template.jinja b/auto_dev/data/templates/customs/method_template.jinja index 3d99f047..d54c59eb 100644 --- a/auto_dev/data/templates/customs/method_template.jinja +++ b/auto_dev/data/templates/customs/method_template.jinja @@ -1,4 +1,4 @@ - def {{ method_name }}(self, _message: ApiHttpMessage{% for param in path_params_snake_case %}, {{ param }}{% endfor %}{% if method|lower in ['post', 'put', 'patch', 'delete'] %}, body{% endif %}): + def {{ method_name }}(self, message: ApiHttpMessage{% for param in path_params_snake_case %}, {{ param }}{% endfor %}{% if method|lower in ['post', 'put', 'patch', 'delete'] %}, body{% endif %}): {% raw %}"""{% endraw %}Handle {{ method|upper }} request for {{ path }}.{% raw %}"""{% endraw %} {%- if path_params_snake_case %} self.context.logger.debug(f"Path parameters: {% for param in path_params_snake_case %}{{ param }}={{ '{' }}{{ param }}{{ '}' }}{% if not loop.last %}, {% endif %}{% endfor %}") @@ -9,12 +9,25 @@ try: {%- if method|lower == 'get' and not path_params %} result = {{ schema|lower|replace(' ', '_') }}_dao.get_all() + {%- elif method|lower == 'get' and path_params %} + result = {{ schema|lower|replace(' ', '_') }}_dao.get_by_id({{ path_params_snake_case[0] }}) + + if result is None: + error_message = json.dumps({"error": f"{{ schema|replace(' ', '_') }} with {{ path_params[0] }} {{ '{' ~ path_params_snake_case[0] ~ '}' }} not found"}) + return ApiHttpMessage( + performative=ApiHttpMessage.Performative.RESPONSE, + status_code=404, + status_text="Not Found", + headers="", + version=message.version, + body=error_message.encode(), + ) {%- elif method|lower == 'post' and operation_type == 'insert' %} {{ schema|lower|replace(' ', '_') }}_body = json.loads(body) result = {{ schema|lower|replace(' ', '_') }}_dao.insert({{ schema|lower|replace(' ', '_') }}_body) {%- elif method|lower == 'post' and operation_type == 'update' %} {{ schema|lower|replace(' ', '_') }}_body = json.loads(body) - result = {{ schema|lower|replace(' ', '_') }}_dao.update({{ path_params_snake_case[0] }}, {{ schema|lower|replace(' ', '_') }}_body) + result = {{ schema|lower|replace(' ', '_') }}_dao.update({{ path_params_snake_case[0] }}, **{{ schema|lower|replace(' ', '_') }}_body) {%- else %} # TODO: Implement {{ method|upper }} logic for {{ path }} raise NotImplementedError @@ -22,27 +35,34 @@ self.context.logger.info("Successfully processed {{ method|upper }} request for {{ path }}") self.context.logger.debug(f"Result: {result}") - return ApiResponse( - headers={{ headers }}, - content=result, + return ApiHttpMessage( + performative=ApiHttpMessage.Performative.RESPONSE, status_code={{ status_code }}, - status_text="{{ status_text }}" + status_text="{{ status_text }}", + headers="", + version=message.version, + body=json.dumps(result).encode() ) {% for error_code, error_info in error_responses.items() %} except {{ error_info.exception }}: self.context.logger.exception("{{ error_info.message }}") - return ApiResponse( - headers={{ headers }}, - content=json.dumps({"error": "{{ error_info.message }}"}).encode("utf-8"), + return ApiHttpMessage( + performative=ApiHttpMessage.Performative.RESPONSE, status_code={{ error_code }}, - status_text="{{ error_info.status_text }}" + status_text="{{ error_info.status_text }}", + headers="", + version=message.version, + body=json.dumps({"error": "{{ error_info.message }}"}).encode("utf-8") ) {% endfor %} except Exception as e: self.context.logger.exception("Unhandled exception") - return ApiResponse( - headers={{ headers }}, - content=json.dumps({"error": str(e)}).encode("utf-8"), + error_message = json.dumps({"error": str(e)}) + return ApiHttpMessage( + performative=ApiHttpMessage.Performative.RESPONSE, status_code=500, - status_text="Internal Server Error" + headers="", + version=message.version, + status_text="Internal Server Error", + body=error_message.encode(), ) diff --git a/auto_dev/data/templates/dao/base_dao.jinja b/auto_dev/data/templates/dao/base_dao.jinja index 362ea97b..c3142ac0 100644 --- a/auto_dev/data/templates/dao/base_dao.jinja +++ b/auto_dev/data/templates/dao/base_dao.jinja @@ -22,6 +22,7 @@ class BaseDAO: def __post_init__(self): """Post initialization setup.""" self.logger = logging.getLogger(f"aea.{self.__class__.__name__}") + self.load_data() @property def data(self) -> dict[str, Any]: @@ -80,8 +81,6 @@ class BaseDAO: def insert(self, data: dict[str, Any] | list[dict[str, Any]]) -> None: """Insert a new item or items into the data.""" - self.load_data() - if self.model_name not in self._data: self._data[self.model_name] = {} diff --git a/auto_dev/data/templates/dao/dao_template.jinja b/auto_dev/data/templates/dao/dao_template.jinja index 23a6ac3e..e51c683f 100644 --- a/auto_dev/data/templates/dao/dao_template.jinja +++ b/auto_dev/data/templates/dao/dao_template.jinja @@ -1,6 +1,6 @@ """{{ model_name }} DAO.""" -from base_dao import BaseDAO +from .base_dao import BaseDAO class {{ model_name }}DAO(BaseDAO):