Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions geoapi/celery_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from celery import Celery
from datetime import timedelta
from geoapi.settings import settings
import socket


CELERY_CONNECTION_STRING = "amqp://{user}:{pwd}@{hostname}/{vhost}".format(
Expand Down Expand Up @@ -36,6 +37,24 @@

app.conf.task_default_queue = "default"

# Redis result backend settings
app.conf.result_backend_transport_options = {
'socket_keepalive': True,
'socket_keepalive_options': {
socket.TCP_KEEPIDLE: 60,
socket.TCP_KEEPCNT: 3,
socket.TCP_KEEPINTVL: 10,
},
'retry_on_timeout': True,
}

# RabbitMQ broker
app.conf.broker_transport_options = {
'confirm_publish': True,
'socket_keepalive': True,
}
app.conf.broker_heartbeat = 60

app.conf.beat_schedule = {
"refresh_projects_watch_content": {
"task": "geoapi.tasks.external_data.refresh_projects_watch_content",
Expand Down
5 changes: 5 additions & 0 deletions geoapi/custom/designsafe/project_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def get_system_users(database_session, user, system_id: str):

from geoapi.utils.external_apis import SystemUser

if system_id is None:
raise GetUsersForProjectNotSupported(
"System ID is None, cannot get users"
)

if not system_id.startswith("project-"):
raise GetUsersForProjectNotSupported(
f"System:{system_id} is not a project so unable to get users"
Expand Down
6 changes: 6 additions & 0 deletions geoapi/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def get_celery_engine():
pool_pre_ping=True, # Check connection health before using
pool_recycle=3600, # Replace connections after 1 hour
pool_timeout=60, # Wait up to 60s for connection
connect_args={
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 3,
},
)
return _celery_engine

Expand Down
79 changes: 35 additions & 44 deletions geoapi/services/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,29 +157,33 @@ def refresh_access_token(database_session, user: User):
RefreshTokenExpired: If the refresh token is expired.
RefreshTokenError: If there is a problem refreshing the token.
"""
username = user.username
tenant_id = user.tenant_id
user_id = user.id

if not user.has_unexpired_refresh_token():
logger.error(
f"Unable to refresh token for user:{user.username} tenant:{user.tenant_id}"
f"Unable to refresh token for user:{username} tenant:{tenant_id}"
f" as refresh token is expired (or possibly never existed)"
)
raise RefreshTokenExpired

try:
logger.info(
f"Refreshing token for user:{user.username}" f" tenant:{user.tenant_id}"
f"Refreshing token for user:{username}" f" tenant:{tenant_id}"
)
with database_session.begin_nested():
# Acquire lock by selecting the auth row for update
# to ensure that only one process is refreshing the tokens at a time
locked_auth = (
database_session.query(Auth)
.filter(Auth.user_id == user.id)
.filter(Auth.user_id == user_id)
.with_for_update()
.one()
)
logger.info(
f"Acquired auth for refreshing token for user:{user.username}"
f" tenant:{user.tenant_id}"
f"Acquired auth for refreshing token for user:{username}"
f" tenant:{tenant_id}"
)

# Check if the tokens were updated while we were getting the `locked_auth`
Expand All @@ -192,29 +196,12 @@ def refresh_access_token(database_session, user: User):
locked_auth.access_token_expires_at - current_time
) > buffer_time:
logger.info(
f"No need to refresh token for user:{user.username}"
f" tenant:{user.tenant_id} as it was recently refreshed"
)
return

# Check if the tokens were updated while we were getting the `locked_auth`
if locked_auth.access_token_expires_at:
current_time = datetime.utcnow().replace(tzinfo=None)
# Make sure `locked_auth.access_token_expires_at` is naive datetime
access_token_expires_at = (
locked_auth.access_token_expires_at.replace(tzinfo=None)
)
buffer_time = timedelta(
seconds=jwt_utils.BUFFER_TIME_WHEN_CHECKING_IF_ACCESS_TOKEN_WAS_RECENTLY_REFRESHED
)
if (access_token_expires_at - current_time) > buffer_time:
logger.info(
f"No need to refresh token for user:{user.username}"
f" tenant:{user.tenant_id} as it was recently refreshed"
f"No need to refresh token for user:{username}"
f" tenant:{tenant_id} as it was recently refreshed"
)
return

tapis_server = get_tapis_api_server(user.tenant_id)
tapis_server = get_tapis_api_server(tenant_id)
body = {
"refresh_token": locked_auth.refresh_token,
}
Expand All @@ -232,38 +219,42 @@ def refresh_access_token(database_session, user: User):
locked_auth.refresh_token_expires_at = data["refresh_token"][
"expires_at"
]
database_session.commit()
logger.info(
f"Finished refreshing token for user:{user.username}"
f" tenant:{user.tenant_id}"
)
jwt_utils.send_refreshed_token_websocket(
user,
{
"username": user.username,
"authToken": {
"token": locked_auth.access_token,
"expiresAt": locked_auth.access_token_expires_at,
},
},
)
else:
logger.error(
f"Problem refreshing token for user:{user.username}"
f" tenant:{user.tenant_id}: {response}, {response.text}"
f"Problem refreshing token for user:{username}"
f" tenant:{tenant_id}: {response}, {response.text}"
)
raise RefreshTokenError

database_session.commit()

logger.info(
f"Finished refreshing token for user:{username}"
f" tenant:{tenant_id}"
)

# Re-query the updated user after the transaction is committed
# (so that the caller has the latest state which includes the updated auth token)
database_session.refresh(user)

jwt_utils.send_refreshed_token_websocket(
user,
{
"username": username,
"authToken": {
"token": locked_auth.access_token,
"expiresAt": locked_auth.access_token_expires_at,
},
},
)
except InvalidRequestError as ire:
logger.exception(
f"Transaction error during token refresh for user:{user.username}: {str(ire)}"
f"Transaction error during token refresh for user:{username}: {str(ire)}"
)
raise RefreshTokenError from ire
except Exception as e:
database_session.rollback()
logger.exception(
f"Error during token refresh for user:{user.username}: {str(e)}"
f"Error during token refresh for user:{username}: {str(e)}"
)
raise RefreshTokenError from e
65 changes: 25 additions & 40 deletions geoapi/tasks/file_location_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
DESIGNSAFE_PUBLISHED_SYSTEM,
"designsafe.storage.community",
]
BATCH_SIZE = 500 # Commit every 500 items


def build_file_index_from_tapis(
client, system_id: str, path: str = "/"
Expand Down Expand Up @@ -175,10 +173,11 @@ def determine_if_exists_in_tree(
# Use first match - path only (no system)
path = "/" + published_paths[0].lstrip("/")

logger.debug(f"Asset '{current_file_path}' found in system at: {path}")
logger.info(f"Asset '{current_file_path}' found in system at: {path}")

return (True, path)

logger.info(f"Asset '{current_file_path}' not found in system at.")
return (False, None)


Expand Down Expand Up @@ -284,22 +283,17 @@ def check_and_update_designsafe_project_id(
item: Union[FeatureAsset, TileServer],
session,
user,
) -> None:
) -> bool:
"""
Check and update the designsafe_project_id for an item based on its current_system.
Uses module-level caching to minimize API calls to DesignSafe.

Args:
item: FeatureAsset or TileServer to update
session: Database session
user: User for API calls
Returns True if designsafe_project_id was set, False if no change was made.
"""

if item.designsafe_project_id:
logger.debug(
f"Nothing to do as item has designsafe_project_id:{item.designsafe_project_id}"
)
return
return False

# Check if we can derive PRJ from published projects path
if (
Expand All @@ -309,50 +303,55 @@ def check_and_update_designsafe_project_id(
):
parts = item.original_path.split("/")
item.designsafe_project_id = parts[2] # PRJ-XXXX
return
return True

# Determine which system to use
system_to_check = item.original_system or item.current_system

if not system_to_check:
logger.debug(f"No system to check for {type(item).__name__}={item.id}")
return
return False

if not is_designsafe_project(system_to_check):
logger.debug(f"System {system_to_check} is not a DesignSafe project, skipping")
return
return False

designsafe_project_id = get_designsafe_project_id(session, user, system_to_check)
if designsafe_project_id:
logger.debug(f"Setting item's designsafe_project_id to {designsafe_project_id}")
item.designsafe_project_id = designsafe_project_id
return True
return False


def check_and_update_public_system(
item: Union[FeatureAsset, TileServer],
published_file_tree_cache: Dict,
session,
user,
) -> None:
) -> bool:
"""
Check if item is on a public system and update location if found in published tree.
Works for both FeatureAsset and TileServer.

Updates is_on_public_system and, if found, current_system/current_path.

Returns True if location was updated, False otherwise.
"""
item_type = type(item).__name__
item_id = item.id

# Skip if already on a public system
if item.current_system in PUBLIC_SYSTEMS:
item.is_on_public_system = True
return
return False

if item.current_system is None:
logger.warning(
f"We don't know the current system: {item_type}={item_id}"
f" original_path={item.original_path} original_system={item.original_system}"
f" current_path={item.current_path} current_system={item.current_system}"
)
return
return False

# Cache published file tree for this system if not already cached
if item.current_system not in published_file_tree_cache:
Expand All @@ -373,6 +372,8 @@ def check_and_update_public_system(
if exists and found_path:
item.current_system = DESIGNSAFE_PUBLISHED_SYSTEM
item.current_path = found_path
return True
return False


@app.task(queue="default")
Expand Down Expand Up @@ -492,9 +493,10 @@ def check_and_update_file_locations(user_id: int, project_id: int):
if designsafe_project_id:
project.designsafe_project_id = designsafe_project_id
session.add(project)
session.commit()

# Process each feature asset
for i, asset in enumerate(feature_assets):
for asset in feature_assets:
try:
# Update timestamp
asset.last_public_system_check = datetime.now(timezone.utc)
Expand Down Expand Up @@ -526,14 +528,7 @@ def check_and_update_file_locations(user_id: int, project_id: int):
)

session.add(asset)

# Commit in large batches for memory management (rare 5000+ item cases)
if (i + 1) % BATCH_SIZE == 0:
session.commit()
session.expire_all()
logger.info(
f"Batch: {i + 1}/{total_checked} processed, {len(failed_items)} errors"
)
session.commit()

except Exception as e:
error_msg = str(e)[:100]
Expand All @@ -554,7 +549,7 @@ def check_and_update_file_locations(user_id: int, project_id: int):
continue

# Process each tile server
for i, tile_server in enumerate(tile_servers, start=len(feature_assets)):
for tile_server in tile_servers:
try:
# Update timestamp
tile_server.last_public_system_check = datetime.now(timezone.utc)
Expand All @@ -579,14 +574,7 @@ def check_and_update_file_locations(user_id: int, project_id: int):
)

session.add(tile_server)

# Commit in large batches
if (i + 1) % BATCH_SIZE == 0:
session.commit()
session.expire_all()
logger.info(
f"Batch: {i + 1}/{total_checked} processed, {len(failed_items)} errors"
)
session.commit()

except Exception as e:
error_msg = str(e)[:100]
Expand All @@ -607,9 +595,6 @@ def check_and_update_file_locations(user_id: int, project_id: int):
session.rollback()
continue

# Final commit for remaining items
session.commit()

# Update final counts
file_location_check.completed_at = datetime.now(timezone.utc)
file_location_check.files_checked = total_checked - len(failed_items)
Expand Down Expand Up @@ -662,5 +647,5 @@ def check_and_update_file_locations(user_id: int, project_id: int):
logger.exception(f"Failed to mark task as failed: {cleanup_error}")
session.rollback()
# Re-raise to mark Celery task as failed as we can't even mark our internal
# task as faile
# task as failed
raise
Loading