Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugins/mqtt-notifications
231 changes: 231 additions & 0 deletions scripts/add_defaults_to_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
#!/usr/bin/env python3
"""
Script to add default values to plugin config schemas where missing.

This ensures that configs never start with None values, improving user experience
and preventing validation errors.
"""

import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional


def get_default_for_field(prop: Dict[str, Any]) -> Any:
"""
Determine a sensible default value for a field based on its type and constraints.

Args:
prop: Field property schema

Returns:
Default value or None if no default should be added
"""
prop_type = prop.get('type')

# Handle union types (array with multiple types)
if isinstance(prop_type, list):
# Use the first non-null type
prop_type = next((t for t in prop_type if t != 'null'), prop_type[0] if prop_type else 'string')

if prop_type == 'boolean':
return False

elif prop_type == 'number':
# For numbers, use minimum if available, or a sensible default
minimum = prop.get('minimum')
maximum = prop.get('maximum')

if minimum is not None:
return minimum
elif maximum is not None:
# Use a reasonable fraction of max (like 30% or minimum 1)
return max(1, int(maximum * 0.3))
else:
# No constraints, use 0
return 0

elif prop_type == 'integer':
# Similar to number
minimum = prop.get('minimum')
maximum = prop.get('maximum')

if minimum is not None:
return minimum
elif maximum is not None:
return max(1, int(maximum * 0.3))
else:
return 0

elif prop_type == 'string':
# Only add default for strings if it makes sense
# Check if there's an enum - use first value
enum_values = prop.get('enum')
if enum_values:
return enum_values[0]

# For optional string fields, empty string might be okay, but be cautious
# We'll skip adding defaults for strings unless explicitly needed
return None

elif prop_type == 'array':
# Empty array as default
return []

elif prop_type == 'object':
# Empty object - but we'll handle nested objects separately
return {}

return None


def should_add_default(prop: Dict[str, Any], field_path: str) -> bool:
"""
Determine if we should add a default value to this field.

Args:
prop: Field property schema
field_path: Dot-separated path to the field

Returns:
True if default should be added
"""
# Skip if already has a default
if 'default' in prop:
return False

# Skip secret fields (they should be user-provided)
if prop.get('x-secret', False):
return False

# Skip API keys and similar sensitive fields
field_name = field_path.split('.')[-1].lower()
sensitive_keywords = ['key', 'password', 'secret', 'token', 'auth', 'credential']
if any(keyword in field_name for keyword in sensitive_keywords):
return False

prop_type = prop.get('type')
if isinstance(prop_type, list):
prop_type = next((t for t in prop_type if t != 'null'), prop_type[0] if prop_type else None)

# Only add defaults for certain types
if prop_type in ('boolean', 'number', 'integer', 'array'):
return True

# For strings, only if there's an enum
if prop_type == 'string' and 'enum' in prop:
return True

return False


def add_defaults_recursive(schema: Dict[str, Any], path: str = "", modified: List[str] = None) -> bool:
"""
Recursively add default values to schema fields.

Args:
schema: Schema dictionary to modify
path: Current path in the schema (for logging)
modified: List to track which fields were modified

Returns:
True if any modifications were made
"""
if modified is None:
modified = []

if not isinstance(schema, dict) or 'properties' not in schema:
return False

changes_made = False

for key, prop in schema['properties'].items():
if not isinstance(prop, dict):
continue

current_path = f"{path}.{key}" if path else key

# Check nested objects
if prop.get('type') == 'object' and 'properties' in prop:
if add_defaults_recursive(prop, current_path, modified):
changes_made = True

# Add default if appropriate
if should_add_default(prop, current_path):
default_value = get_default_for_field(prop)
if default_value is not None:
prop['default'] = default_value
modified.append(current_path)
changes_made = True
print(f" Added default to {current_path}: {default_value} (type: {prop.get('type')})")

return changes_made


def process_schema_file(schema_path: Path) -> bool:
"""
Process a single schema file to add defaults.

Args:
schema_path: Path to the schema file

Returns:
True if file was modified
"""
print(f"\nProcessing: {schema_path}")

try:
with open(schema_path, 'r', encoding='utf-8') as f:
schema = json.load(f)
except Exception as e:
print(f" Error reading schema: {e}")
return False

modified_fields = []
changes_made = add_defaults_recursive(schema, modified=modified_fields)

if changes_made:
# Write back with pretty formatting
with open(schema_path, 'w', encoding='utf-8') as f:
json.dump(schema, f, indent=2, ensure_ascii=False)
f.write('\n') # Add trailing newline

print(f" ✓ Modified {len(modified_fields)} fields")
return True
else:
print(f" ✓ No changes needed")
return False


def main():
"""Main entry point."""
project_root = Path(__file__).parent.parent
plugins_dir = project_root / 'plugins'

if not plugins_dir.exists():
print(f"Error: Plugins directory not found: {plugins_dir}")
sys.exit(1)

# Find all config_schema.json files
schema_files = list(plugins_dir.rglob('config_schema.json'))

if not schema_files:
print("No config_schema.json files found")
sys.exit(0)

print(f"Found {len(schema_files)} schema files")

modified_count = 0
for schema_file in sorted(schema_files):
if process_schema_file(schema_file):
modified_count += 1

print(f"\n{'='*60}")
print(f"Summary: Modified {modified_count} out of {len(schema_files)} schema files")
print(f"{'='*60}")


if __name__ == '__main__':
main()

46 changes: 38 additions & 8 deletions src/plugin_system/schema_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,24 +394,54 @@ def _format_validation_error(self, error: ValidationError, plugin_id: Optional[s
def merge_with_defaults(self, config: Dict[str, Any], defaults: Dict[str, Any]) -> Dict[str, Any]:
"""
Merge configuration with defaults, preserving user values.
Also replaces None values with defaults to ensure config never has None from the start.

Args:
config: User configuration
defaults: Default values from schema

Returns:
Merged configuration with defaults applied where missing
Merged configuration with defaults applied where missing or None
"""
merged = defaults.copy()
merged = copy.deepcopy(defaults)

def deep_merge(target: Dict[str, Any], source: Dict[str, Any]) -> None:
"""Recursively merge source into target."""
def deep_merge(target: Dict[str, Any], source: Dict[str, Any], default_dict: Dict[str, Any]) -> None:
"""Recursively merge source into target, replacing None with defaults."""
for key, value in source.items():
default_value = default_dict.get(key)

if key in target and isinstance(target[key], dict) and isinstance(value, dict):
deep_merge(target[key], value)
# Both are dicts, recursively merge
if isinstance(default_value, dict):
deep_merge(target[key], value, default_value)
else:
deep_merge(target[key], value, {})
elif value is None and default_value is not None:
# Value is None and we have a default, use the default
target[key] = copy.deepcopy(default_value) if isinstance(default_value, (dict, list)) else default_value
else:
target[key] = value

deep_merge(merged, config)
# Normal merge: user value takes precedence (copy if dict/list)
if isinstance(value, (dict, list)):
target[key] = copy.deepcopy(value)
else:
target[key] = value

deep_merge(merged, config, defaults)

# Final pass: replace any remaining None values at any level with defaults
def replace_none_with_defaults(target: Dict[str, Any], default_dict: Dict[str, Any]) -> None:
"""Recursively replace None values with defaults."""
for key in list(target.keys()):
value = target[key]
default_value = default_dict.get(key)

if value is None and default_value is not None:
# Replace None with default
target[key] = copy.deepcopy(default_value) if isinstance(default_value, (dict, list)) else default_value
elif isinstance(value, dict) and isinstance(default_value, dict):
# Recursively process nested dicts
replace_none_with_defaults(value, default_value)

replace_none_with_defaults(merged, defaults)
return merged

33 changes: 9 additions & 24 deletions web_interface/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,14 @@
app.secret_key = os.urandom(24)
config_manager = ConfigManager()

# Initialize CSRF protection (optional for local-only, but recommended for defense-in-depth)
try:
from flask_wtf.csrf import CSRFProtect
csrf = CSRFProtect(app)
# Exempt SSE streams from CSRF (read-only)
from functools import wraps
from flask import request

def csrf_exempt(f):
"""Decorator to exempt a route from CSRF protection."""
f.csrf_exempt = True
return f

# Mark SSE streams as exempt
@app.before_request
def check_csrf_exempt():
"""Check if route should be exempt from CSRF."""
if request.endpoint and 'stream' in request.endpoint:
# SSE streams are read-only, exempt from CSRF
pass
except ImportError:
# flask-wtf not installed, CSRF protection disabled
csrf = None
pass
# CSRF protection disabled for local-only application
# CSRF is designed for internet-facing web apps to prevent cross-site request forgery.
# For a local-only Raspberry Pi application, the threat model is different:
# - If an attacker has network access to perform CSRF, they have other attack vectors
# - All API endpoints are programmatic (HTMX/fetch) and don't include CSRF tokens
# - Forms use HTMX which doesn't automatically include CSRF tokens
# If you need CSRF protection (e.g., exposing to internet), properly implement CSRF tokens in HTMX forms
csrf = None

# Initialize rate limiting (prevent accidental abuse, not security)
try:
Expand Down Expand Up @@ -543,6 +527,7 @@ def stream_logs():
csrf.exempt(stream_stats)
csrf.exempt(stream_display)
csrf.exempt(stream_logs)
# Note: api_v3 blueprint is exempted above after registration

if limiter:
limiter.limit("20 per minute")(stream_stats)
Expand Down
Loading