Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ packages = ["src/myob"]
[tool.ruff]
target-version = "py312"
line-length = 100
exclude = [
"build",
"dist",
"tests"
]

[tool.ruff.lint]
select = [
Expand All @@ -49,17 +54,18 @@ select = [
"I", # isort
"N", # pep8-naming
"UP", # pyupgrade
# "ANN", # flake8-annotations
"ANN", # flake8-annotations
"B", # flake8-bugbear
"S", # flake8-bandit
"T10", # debugger
"TID", # flake8-tidy-imports
]
ignore = [
"E501"
"E501",
"ANN401",
]

[tool.ruff.lint.isort]
extra-standard-library = [
"requests",
]
]
30 changes: 16 additions & 14 deletions src/myob/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from .credentials import PartnerCredentials
from .endpoints import ALL, ENDPOINTS, GET
from .managers import Manager
Expand All @@ -6,7 +8,7 @@
class Myob:
"""An ORM-like interface to the MYOB API."""

def __init__(self, credentials):
def __init__(self, credentials: PartnerCredentials) -> None:
if not isinstance(credentials, PartnerCredentials):
raise TypeError(f"Expected a Credentials instance, got {type(credentials).__name__}.")
self.credentials = credentials
Expand All @@ -23,16 +25,16 @@ def __init__(self, credentials):
],
)

def info(self):
return self._manager.info()
def info(self) -> str:
return self._manager.info() # type: ignore[attr-defined]

def __repr__(self):
def __repr__(self) -> str:
options = "\n ".join(["companyfiles", "info"])
return f"Myob:\n {options}"


class CompanyFiles:
def __init__(self, credentials):
def __init__(self, credentials: PartnerCredentials) -> None:
self.credentials = credentials
self._manager = Manager(
"",
Expand All @@ -44,42 +46,42 @@ def __init__(self, credentials):
)
self._manager.name = "CompanyFile"

def all(self):
raw_companyfiles = self._manager.all()
def all(self) -> list["CompanyFile"]:
raw_companyfiles = self._manager.all() # type: ignore[attr-defined]
return [
CompanyFile(raw_companyfile, self.credentials) for raw_companyfile in raw_companyfiles
]

def get(self, id, call=True):
def get(self, id: str, call: bool = True) -> "CompanyFile":
if call:
# raw_companyfile = self._manager.get(id=id)['CompanyFile']
# NOTE: Annoyingly, we need to pass company_id to the manager, else we won't have permission
# on the GET endpoint. The only way we currently allow passing company_id is by setting it on the manager,
# and we can't do that on init, as this is a manager for company files plural..
# Reluctant to change manager code, as it would add confusion if the inner method let you override the company_id.
manager = Manager("", self.credentials, raw_endpoints=[(GET, "", "")], company_id=id)
raw_companyfile = manager.get()["CompanyFile"]
raw_companyfile = manager.get()["CompanyFile"] # type: ignore[attr-defined]
else:
raw_companyfile = {"Id": id}
return CompanyFile(raw_companyfile, self.credentials)

def __repr__(self):
def __repr__(self) -> str:
return self._manager.__repr__()


class CompanyFile:
def __init__(self, raw, credentials):
def __init__(self, raw: dict[str, Any], credentials: PartnerCredentials) -> None:
self.id = raw["Id"]
self.name = raw.get("Name")
self.data = raw # Dump remaining raw data here.
self.credentials = credentials
for k, v in ENDPOINTS.items():
setattr(
self,
v["name"],
v["name"], # type: ignore[arg-type]
Manager(k, credentials, endpoints=v["methods"], company_id=self.id),
)

def __repr__(self):
options = "\n ".join(sorted(v["name"] for v in ENDPOINTS.values()))
def __repr__(self) -> str:
options = "\n ".join(sorted(v["name"] for v in ENDPOINTS.values())) # type: ignore[misc]
return f"CompanyFile:\n {options}"
45 changes: 23 additions & 22 deletions src/myob/credentials.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import datetime
from datetime import datetime, timedelta
from typing import Any

from requests_oauthlib import OAuth2Session

Expand All @@ -11,17 +12,17 @@ class PartnerCredentials:

def __init__(
self,
consumer_key,
consumer_secret,
callback_uri,
verified=False,
companyfile_credentials={}, # noqa: B006
oauth_token=None,
refresh_token=None,
oauth_expires_at=None,
scope=None,
state=None,
):
consumer_key: str,
consumer_secret: str,
callback_uri: str,
verified: bool = False,
companyfile_credentials: dict[str, str] = {}, # noqa: B006
oauth_token: str | None = None,
refresh_token: str | None = None,
oauth_expires_at: datetime | None = None,
scope: None = None, # TODO: Review if used.
state: str | None = None,
) -> None:
self.consumer_key = consumer_key
self.consumer_secret = consumer_secret
self.callback_uri = callback_uri
Expand All @@ -32,7 +33,7 @@ def __init__(
self.refresh_token = refresh_token

if oauth_expires_at is not None:
if not isinstance(oauth_expires_at, datetime.datetime):
if not isinstance(oauth_expires_at, datetime):
raise ValueError("'oauth_expires_at' must be a datetime instance.")
self.oauth_expires_at = oauth_expires_at

Expand All @@ -42,13 +43,13 @@ def __init__(

# TODO: Add `verify` kwarg here, which will quickly throw the provided credentials at a
# protected endpoint to ensure they are valid. If not, raise appropriate error.
def authenticate_companyfile(self, company_id, username, password):
def authenticate_companyfile(self, company_id: str, username: str, password: str) -> None:
"""Store hashed username-password for logging into company file."""
userpass = base64.b64encode(bytes(f"{username}:{password}", "utf-8")).decode("utf-8")
self.companyfile_credentials[company_id] = userpass

@property
def state(self):
def state(self) -> dict[str, Any]:
"""Get a representation of this credentials object from which it can be reconstructed."""
return {
attr: getattr(self, attr)
Expand All @@ -65,7 +66,7 @@ def state(self):
if getattr(self, attr) is not None
}

def expired(self, now=None):
def expired(self, now: datetime | None = None) -> bool:
"""Determine whether the current access token has expired."""
# Expiry might be unset if the user hasn't finished authenticating.
if self.oauth_expires_at is None:
Expand All @@ -76,10 +77,10 @@ def expired(self, now=None):
# they can use self.oauth_expires_at
CONSERVATIVE_SECONDS = 30 # noqa: N806

now = now or datetime.datetime.now()
return self.oauth_expires_at <= (now + datetime.timedelta(seconds=CONSERVATIVE_SECONDS))
now = now or datetime.now()
return self.oauth_expires_at <= (now + timedelta(seconds=CONSERVATIVE_SECONDS))

def verify(self, code):
def verify(self, code: str) -> None:
"""Verify an OAuth session, retrieving an access token."""
token = self._oauth.fetch_token(
MYOB_PARTNER_BASE_URL + ACCESS_TOKEN_URL,
Expand All @@ -89,7 +90,7 @@ def verify(self, code):
)
self.save_token(token)

def refresh(self):
def refresh(self) -> None:
"""Refresh an expired token."""
token = self._oauth.refresh_token(
MYOB_PARTNER_BASE_URL + ACCESS_TOKEN_URL,
Expand All @@ -99,9 +100,9 @@ def refresh(self):
)
self.save_token(token)

def save_token(self, token):
def save_token(self, token: dict) -> None:
self.oauth_token = token.get("access_token")
self.refresh_token = token.get("refresh_token")

self.oauth_expires_at = datetime.datetime.fromtimestamp(token.get("expires_at"))
self.oauth_expires_at = datetime.fromtimestamp(token.get("expires_at")) # type: ignore[arg-type]
self.verified = True
13 changes: 7 additions & 6 deletions src/myob/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from .types import Method
from .utils import pluralise

ALL = "ALL"
GET = "GET" # this method expects a UID as a keyword
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
ALL: Method = "ALL"
GET: Method = "GET" # this method expects a UID as a keyword
POST: Method = "POST"
PUT: Method = "PUT"
DELETE: Method = "DELETE"
CRUD = "CRUD" # shorthand for creating the ALL|GET|POST|PUT|DELETE endpoints in one swoop

METHOD_ORDER = [ALL, GET, POST, PUT, DELETE]
METHOD_ORDER: list[Method] = [ALL, GET, POST, PUT, DELETE]

ENDPOINTS = {
"Banking/": {
Expand Down
5 changes: 4 additions & 1 deletion src/myob/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from requests import Response


class MyobException(Exception): # noqa: N818
def __init__(self, response, msg=None):
def __init__(self, response: Response, msg: str | None = None) -> None:
self.response = response
try:
self.errors = response.json()["Errors"]
Expand Down
48 changes: 29 additions & 19 deletions src/myob/managers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import re
import requests
from datetime import date
from typing import Any

from .constants import DEFAULT_PAGE_SIZE, MYOB_BASE_URL
from .endpoints import CRUD, METHOD_MAPPING, METHOD_ORDER
from .credentials import PartnerCredentials
from .endpoints import ALL, CRUD, GET, METHOD_MAPPING, METHOD_ORDER, POST, PUT, Method
from .exceptions import (
MyobBadRequest,
MyobConflict,
Expand All @@ -15,18 +17,26 @@
MyobRateLimitExceeded,
MyobUnauthorized,
)
from .types import MethodDetails


class Manager:
def __init__(self, name, credentials, company_id=None, endpoints=[], raw_endpoints=[]): # noqa: B006
def __init__(
self,
name: str,
credentials: PartnerCredentials,
company_id: str | None = None,
endpoints: list = [], # noqa: B006
raw_endpoints: list = [], # noqa: B006
) -> None:
self.credentials = credentials
self.name = "_".join(p for p in name.rstrip("/").split("/") if "[" not in p)
self.base_url = MYOB_BASE_URL
if company_id is not None:
self.base_url += company_id + "/"
if name:
self.base_url += name
self.method_details = {}
self.method_details: dict[str, MethodDetails] = {}
self.company_id = company_id

# Build ORM methods from given url endpoints.
Expand All @@ -48,16 +58,16 @@ def __init__(self, name, credentials, company_id=None, endpoints=[], raw_endpoin
for method, endpoint, hint in raw_endpoints:
self.build_method(method, endpoint, hint)

def build_method(self, method, endpoint, hint):
def build_method(self, method: Method, endpoint: str, hint: str) -> None:
full_endpoint = self.base_url + endpoint
url_keys = re.findall(r"\[([^\]]*)\]", full_endpoint)
template = full_endpoint.replace("[", "{").replace("]", "}")

required_kwargs = url_keys.copy()
if method in ("PUT", "POST"):
if method in (PUT, POST):
required_kwargs.append("data")

def inner(*args, timeout=None, **kwargs):
def inner(*args: Any, timeout: int | None = None, **kwargs: Any) -> str | dict:
if args:
raise AttributeError("Unnamed args provided. Only keyword args accepted.")

Expand All @@ -78,7 +88,7 @@ def inner(*args, timeout=None, **kwargs):
request_kwargs_raw[k] = v

# Determine request method.
request_method = "GET" if method == "ALL" else method
request_method = GET if method == ALL else method

# Build url.
url = template.format(**url_kwargs)
Expand Down Expand Up @@ -130,13 +140,13 @@ def inner(*args, timeout=None, **kwargs):
# If it already exists, prepend with method to disambiguate.
elif hasattr(self, method_name):
method_name = f"{method.lower()}_{method_name}"
self.method_details[method_name] = {
"kwargs": required_kwargs,
"hint": hint,
}
self.method_details[method_name] = MethodDetails(
kwargs=required_kwargs,
hint=hint,
)
setattr(self, method_name, inner)

def build_request_kwargs(self, method, data=None, **kwargs):
def build_request_kwargs(self, method: Method, data: dict | None = None, **kwargs: Any) -> dict:
request_kwargs = {}

# Build headers.
Expand Down Expand Up @@ -166,7 +176,7 @@ def build_request_kwargs(self, method, data=None, **kwargs):
request_kwargs["params"] = {}
filters = []

def build_value(value):
def build_value(value: Any) -> str:
if issubclass(type(value), date):
return f"datetime'{value}'"
if isinstance(value, bool):
Expand Down Expand Up @@ -205,10 +215,10 @@ def build_value(value):
page_size = DEFAULT_PAGE_SIZE
if "limit" in kwargs:
page_size = int(kwargs["limit"])
request_kwargs["params"]["$top"] = page_size
request_kwargs["params"]["$top"] = page_size # type: ignore[assignment]

if "page" in kwargs:
request_kwargs["params"]["$skip"] = (int(kwargs["page"]) - 1) * page_size
request_kwargs["params"]["$skip"] = (int(kwargs["page"]) - 1) * page_size # type: ignore[assignment]

if "format" in kwargs:
request_kwargs["params"]["format"] = kwargs["format"]
Expand All @@ -225,16 +235,16 @@ def build_value(value):

return request_kwargs

def __repr__(self):
def _get_signature(name, kwargs):
def __repr__(self) -> str:
def _get_signature(name: str, kwargs: list[str]) -> str:
return f"{name}({', '.join(kwargs)})"

def print_method(name, kwargs, hint, offset):
def _print_method(name: str, kwargs: list[str], hint: str, offset: int) -> str:
return f"{_get_signature(name, kwargs):>{offset}} - {hint}"

offset = max(len(_get_signature(k, v["kwargs"])) for k, v in self.method_details.items())
options = "\n ".join(
print_method(k, v["kwargs"], v["hint"], offset)
_print_method(k, v["kwargs"], v["hint"], offset)
for k, v in sorted(self.method_details.items())
)
return f"{self.name}{self.__class__.__name__}:\n {options}"
Loading
Loading