Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ jobs:
run: |
echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}' | tr '[:lower:]' '[:upper:]')" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append

- name: Lint with mypy
shell: bash -l {0}
run: python -m mypy mp_api/

- name: Test with pytest
env:
MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }}
Expand Down
Empty file added mp_api/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion mp_api/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
try:
__version__ = version("mp_api")
except PackageNotFoundError: # pragma: no cover
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION")
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION", "")
52 changes: 31 additions & 21 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from tqdm.auto import tqdm
from urllib3.util.retry import Retry

from mp_api.client.core.exceptions import MPRestError
from mp_api.client.core.exceptions import MPRestError, _emit_status_warning
from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
from mp_api.client.core.utils import (
load_json,
Expand All @@ -46,8 +46,10 @@

try:
import flask

_flask_is_installed = True
except ImportError:
flask = None
_flask_is_installed = False

if TYPE_CHECKING:
from typing import Any, Callable
Expand All @@ -59,7 +61,7 @@
try:
__version__ = version("mp_api")
except PackageNotFoundError: # pragma: no cover
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION")
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION", "")


class _DictLikeAccess(BaseModel):
Expand All @@ -83,7 +85,7 @@ class BaseRester:
"""Base client class with core stubs."""

suffix: str = ""
document_model: type[BaseModel] | None = None
document_model: type[BaseModel] = _DictLikeAccess
primary_key: str = "material_id"

def __init__(
Expand Down Expand Up @@ -222,7 +224,10 @@ def _get_database_version(endpoint):

Returns: database version as a string
"""
return requests.get(url=endpoint + "heartbeat").json()["db_version"]
if (get_resp := requests.get(url=endpoint + "heartbeat")).status_code == 403:
_emit_status_warning()
return
return get_resp.json()["db_version"]

def _post_resource(
self,
Expand Down Expand Up @@ -407,7 +412,7 @@ def _query_resource(
num_chunks: int | None = None,
chunk_size: int | None = None,
timeout: int | None = None,
) -> dict:
) -> dict[str, Any]:
"""Query the endpoint for a Resource containing a list of documents
and meta information about pagination and total document count.

Expand Down Expand Up @@ -539,12 +544,15 @@ def _query_resource(
for docs, _, _ in byte_data:
unzipped_data.extend(docs)

data = {"data": unzipped_data, "meta": {}}

if self.use_document_model:
data["data"] = self._convert_to_model(data["data"])
data: dict[str, Any] = {
"data": (
self._convert_to_model(unzipped_data) # type: ignore[arg-type]
if self.use_document_model
else unzipped_data
),
"meta": {"total_doc": len(unzipped_data)},
}

data["meta"]["total_doc"] = len(data["data"])
else:
data = self._submit_requests(
url=url,
Expand Down Expand Up @@ -672,7 +680,7 @@ def _submit_requests( # noqa
new_limits = [chunk_size]

total_num_docs = 0
total_data: dict[str, list[Any]] = {"data": []}
total_data: dict[str, Any] = {"data": []}

# Obtain first page of results and get pagination information.
# Individual total document limits (subtotal) will potentially
Expand Down Expand Up @@ -871,7 +879,7 @@ def _multi_thread(
func: Callable,
params_list: list[dict],
progress_bar: tqdm | None = None,
):
) -> list[tuple[Any, int, int]]:
"""Handles setting up a threadpool and sending parallel requests.

Arguments:
Expand Down Expand Up @@ -962,7 +970,7 @@ def _submit_request_and_process(
Tuple with data and total number of docs in matching the query in the database.
"""
headers = None
if flask is not None and flask.has_request_context():
if _flask_is_installed and flask.has_request_context():
headers = flask.request.headers

try:
Expand Down Expand Up @@ -1015,7 +1023,9 @@ def _submit_request_and_process(
f"on URL {response.url} with message:\n{message}"
)

def _convert_to_model(self, data: list[dict]):
def _convert_to_model(
self, data: list[dict[str, Any]]
) -> list[BaseModel] | list[dict[str, Any]]:
"""Converts dictionary documents to instantiated MPDataDoc objects.

Args:
Expand All @@ -1028,7 +1038,7 @@ def _convert_to_model(self, data: list[dict]):
if len(data) > 0:
data_model, set_fields, _ = self._generate_returned_model(data[0])

data = [
return [
data_model(
**{
field: value
Expand All @@ -1043,7 +1053,7 @@ def _convert_to_model(self, data: list[dict]):

def _generate_returned_model(
self, doc: dict[str, Any]
) -> tuple[BaseModel, list[str], list[str]]:
) -> tuple[type[BaseModel], list[str], list[str]]:
model_fields = self.document_model.model_fields
set_fields = [k for k in doc if k in model_fields]
unset_fields = [field for field in model_fields if field not in set_fields]
Expand All @@ -1059,13 +1069,13 @@ def _generate_returned_model(
):
vars(import_module(self.document_model.__module__))

include_fields: dict[str, tuple[type, FieldInfo]] = {}
include_fields: dict[str, tuple[Any, FieldInfo]] = {}
for name in set_fields:
field_copy = model_fields[name]._copy()
if not field_copy.default_factory:
# Fields with a default_factory cannot also have a default in pydantic>=2.12.3
field_copy.default = None
include_fields[name] = (
include_fields[name] = ( # type: ignore[assignment]
Optional[model_fields[name].annotation],
field_copy,
)
Expand Down Expand Up @@ -1202,7 +1212,7 @@ def get_data_by_id(
self,
document_id: str,
fields: list[str] | None = None,
) -> BaseModel | dict:
) -> BaseModel | dict[str, Any] | None:
warnings.warn(
"get_data_by_id is deprecated and will be removed soon. Please use the search method instead.",
DeprecationWarning,
Expand All @@ -1221,7 +1231,7 @@ def get_data_by_id(
if isinstance(fields, str): # pragma: no cover
fields = (fields,) # type: ignore

docs = self._search( # type: ignorech( # type: ignorech( # type: ignore
docs = self._search(
**{self.primary_key + "s": document_id},
num_chunks=1,
chunk_size=1,
Expand Down
12 changes: 12 additions & 0 deletions mp_api/client/core/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
"""Define custom exceptions and warnings for the client."""
from __future__ import annotations

import warnings


class MPRestError(Exception):
"""Raised when the query has problems, e.g., bad query format."""


class MPRestWarning(Warning):
"""Raised when a query is malformed but interpretable."""


def _emit_status_warning() -> None:
"""Emit a warning if client can't hear a heartbeat."""
warnings.warn(
"Cannot listen to heartbeat, check Materials Project "
"status page: https://status.materialsproject.org/",
category=MPRestWarning,
stacklevel=2,
)
4 changes: 2 additions & 2 deletions mp_api/client/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_MAX_HTTP_URL_LENGTH = PMG_SETTINGS.get("MPRESTER_MAX_HTTP_URL_LENGTH", 2000)
_MAX_LIST_LENGTH = min(PMG_SETTINGS.get("MPRESTER_MAX_LIST_LENGTH", 10000), 10000)

_EMMET_SETTINGS = EmmetSettings()
_EMMET_SETTINGS = EmmetSettings() # type: ignore[call-arg]
_DEFAULT_ENDPOINT = "https://api.materialsproject.org/"


Expand Down Expand Up @@ -109,4 +109,4 @@ def _get_endpoint_from_env(cls, v: str | None) -> str:
return v or os.environ.get("MP_API_ENDPOINT") or _DEFAULT_ENDPOINT


MAPI_CLIENT_SETTINGS = MAPIClientSettings()
MAPI_CLIENT_SETTINGS: MAPIClientSettings = MAPIClientSettings() # type: ignore[call-arg]
12 changes: 6 additions & 6 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,11 @@ def __init__(
import_str : str
A dot-separated, import-like string.
"""
if len(split_import_str := import_str.rsplit(".", 1)) > 1:
self._module_name, self._class_name = split_import_str
if len(split_import_str := import_str.rsplit(".", 1)) == 1:
self._module_name: str = split_import_str[0]
self._class_name: str | None = None
else:
self._module_name = split_import_str[0]
self._class_name = None
self._module_name, self._class_name = split_import_str

self._imported: Any | None = None
self._obj: Any | None = None
Expand Down Expand Up @@ -216,9 +216,9 @@ def __call__(self, *args, **kwargs) -> Any:
if isinstance(self._imported, type):
self._obj = self._imported(*args, **kwargs)
return self._obj
else:
elif callable(self._imported):
self._obj = self._imported
return self._obj(*args, **kwargs)
return self._obj(*args, **kwargs) # type: ignore[misc]

def __getattr__(self, v: str) -> Any:
"""Get an attribute on a super lazy object."""
Expand Down
Loading