diff --git a/wavefront/server/apps/floconsole/floconsole/config.ini b/wavefront/server/apps/floconsole/floconsole/config.ini index 41ea088e..b6e71275 100644 --- a/wavefront/server/apps/floconsole/floconsole/config.ini +++ b/wavefront/server/apps/floconsole/floconsole/config.ini @@ -8,6 +8,9 @@ db_name = ${CONSOLE_DB_NAME} [env_config] app_env = ${APP_ENV} +[cloud_config] +cloud_provider = ${CLOUD_PROVIDER} + [jwt_token] token_expiry=${TOKEN_EXPIRY} temporary_token_expiry=${TEMPORARY_TOKEN_EXPIRY} diff --git a/wavefront/server/apps/floconsole/floconsole/di/application_container.py b/wavefront/server/apps/floconsole/floconsole/di/application_container.py index 081c386e..8a74f040 100644 --- a/wavefront/server/apps/floconsole/floconsole/di/application_container.py +++ b/wavefront/server/apps/floconsole/floconsole/di/application_container.py @@ -66,12 +66,8 @@ class ApplicationContainer(containers.DeclarativeContainer): app_user_repository=app_user_repository, ) - kms_service = providers.Selector( - config.jwt_token.enable_cloud_kms, - true=providers.Singleton( - FloKmsService, cloud_provider=config.cloud_config.cloud_provider - ), - false=providers.Object(None), # No KMS service if cloud KMS is not enabled + kms_service = providers.Singleton( + FloKmsService, cloud_provider=config.cloud_config.cloud_provider ) token_service = providers.Singleton( diff --git a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/main.py b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/main.py index 50a5931c..ee0441dd 100644 --- a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/main.py +++ b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/main.py @@ -15,11 +15,15 @@ def main(): cache_manager = CacheManager(namespace='rag') encryption_service = None if ( - (CLOUD_PROVIDER == 'aws' and os.getenv('AWS_KMS_ARN') is not None) - or CLOUD_PROVIDER == 'gcp' - and ( - os.getenv('GCP_KMS_KEY_RING') is not None - and os.getenv('GCP_KMS_CRYPTO_KEY') is not None + (CLOUD_PROVIDER == 'aws' and os.getenv('AWS_KMS_ARN')) + or ( + CLOUD_PROVIDER == 'gcp' + and (os.getenv('GCP_KMS_KEY_RING') and os.getenv('GCP_KMS_CRYPTO_KEY')) + ) + or ( + CLOUD_PROVIDER == 'azure' + and os.getenv('AZURE_KEY_VAULT_URL') + and os.getenv('AZURE_KEY_VAULT_KEY_NAME') ) ): encryption_service = FloKmsService(cloud_provider=CLOUD_PROVIDER) diff --git a/wavefront/server/modules/auth_module/auth_module/auth_container.py b/wavefront/server/modules/auth_module/auth_module/auth_container.py index b277e409..1737a03f 100644 --- a/wavefront/server/modules/auth_module/auth_module/auth_container.py +++ b/wavefront/server/modules/auth_module/auth_module/auth_container.py @@ -37,12 +37,8 @@ class AuthContainer(containers.DeclarativeContainer): db_client=db_client, ) - kms_service = providers.Selector( - config.jwt_token.enable_cloud_kms, - true=providers.Singleton( - FloKmsService, cloud_provider=config.cloud_config.cloud_provider - ), - false=providers.Object(None), # No KMS service if cloud KMS is not enabled + kms_service = providers.Singleton( + FloKmsService, cloud_provider=config.cloud_config.cloud_provider ) token_service = providers.Singleton( diff --git a/wavefront/server/packages/flo_cloud/flo_cloud/azure/__init__.py b/wavefront/server/packages/flo_cloud/flo_cloud/azure/__init__.py index 4eef673d..c9ebfe44 100644 --- a/wavefront/server/packages/flo_cloud/flo_cloud/azure/__init__.py +++ b/wavefront/server/packages/flo_cloud/flo_cloud/azure/__init__.py @@ -1,4 +1,9 @@ +import logging + from .blob_storage import AzureBlobStorage from .storage_queue import StorageQueue +from .key_vault import AzureKMS + +logging.getLogger('azure').setLevel(logging.WARNING) -__all__ = ['AzureBlobStorage', 'StorageQueue'] +__all__ = ['AzureBlobStorage', 'AzureKMS', 'StorageQueue'] diff --git a/wavefront/server/packages/flo_cloud/flo_cloud/azure/key_vault.py b/wavefront/server/packages/flo_cloud/flo_cloud/azure/key_vault.py new file mode 100644 index 00000000..7bf441d0 --- /dev/null +++ b/wavefront/server/packages/flo_cloud/flo_cloud/azure/key_vault.py @@ -0,0 +1,117 @@ +import os +from typing import Optional + +from azure.identity import ClientSecretCredential, DefaultAzureCredential +from azure.keyvault.keys import KeyClient +from azure.keyvault.keys.crypto import ( + CryptographyClient, + EncryptionAlgorithm, + SignatureAlgorithm, +) +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers + +from .._types import FloKMS + + +class AzureKMS(FloKMS): + """Azure Key Vault implementation of FloKMS. + + Authentication modes (same as AzureBlobStorage): + 1. Service Principal — provide client_id, client_secret, tenant_id explicitly, + or set AZURE_CLIENT_ID / AZURE_CLIENT_SECRET / AZURE_TENANT_ID env vars. + 2. DefaultAzureCredential — falls back to Workload Identity, Managed Identity, + Azure CLI, etc. + + Required env vars: + AZURE_KEY_VAULT_URL — e.g. https://my-vault.vault.azure.net/ + AZURE_KEY_VAULT_KEY_NAME — name of the RSA key in the vault + + Optional env var: + AZURE_KEY_VAULT_KEY_VERSION — specific key version; omit to use the latest + """ + + def __init__( + self, + vault_url: Optional[str] = None, + key_name: Optional[str] = None, + key_version: Optional[str] = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, + tenant_id: Optional[str] = None, + ): + resolved_vault_url = vault_url or os.environ.get('AZURE_KEY_VAULT_URL') + resolved_key_name = key_name or os.environ.get('AZURE_KEY_VAULT_KEY_NAME') + resolved_key_version = key_version or os.environ.get( + 'AZURE_KEY_VAULT_KEY_VERSION' + ) + + if not resolved_vault_url: + raise ValueError( + 'vault_url must be provided or AZURE_KEY_VAULT_URL must be set' + ) + if not resolved_key_name: + raise ValueError( + 'key_name must be provided or AZURE_KEY_VAULT_KEY_NAME must be set' + ) + + creds_provided = [client_id, client_secret, tenant_id] + if all(creds_provided): + credential = ClientSecretCredential( + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret, + ) + elif any(creds_provided): + raise ValueError( + 'Partial credentials provided. Supply all of client_id, ' + 'client_secret, and tenant_id, or none to use DefaultAzureCredential.' + ) + else: + credential = DefaultAzureCredential() + + self._key_name = resolved_key_name + self._key_version = resolved_key_version + self.key_client = KeyClient(vault_url=resolved_vault_url, credential=credential) + key = self.key_client.get_key(resolved_key_name, version=resolved_key_version) + self.crypto_client = CryptographyClient(key, credential=credential) + + def encrypt(self, plaintext: str) -> bytes: + if isinstance(plaintext, str): + plaintext = plaintext.encode('utf-8') + result = self.crypto_client.encrypt(EncryptionAlgorithm.rsa_oaep_256, plaintext) + return result.ciphertext + + def decrypt(self, ciphertext: str) -> bytes: + if isinstance(ciphertext, str): + ciphertext = ciphertext.encode('utf-8') + result = self.crypto_client.decrypt( + EncryptionAlgorithm.rsa_oaep_256, ciphertext + ) + return result.plaintext + + def sign(self, message: bytes, **kwargs) -> bytes: + algorithm = kwargs.get('signing_algorithm', SignatureAlgorithm.ps256) + result = self.crypto_client.sign(algorithm, message) + return result.signature + + def verify(self, message: bytes, signature: bytes, **kwargs) -> bool: + algorithm = kwargs.get('signing_algorithm', SignatureAlgorithm.ps256) + result = self.crypto_client.verify(algorithm, message, signature) + return result.is_valid + + def get_public_key_pem(self, **kwargs) -> str | bytes: + key = self.key_client.get_key(self._key_name, version=self._key_version) + jwk = key.key + + # Decode the JWK RSA public key components (big-endian bytes) to integers + n = int.from_bytes(jwk.n, byteorder='big') + e = int.from_bytes(jwk.e, byteorder='big') + + public_key = RSAPublicNumbers(e=e, n=n).public_key(default_backend()) + pem_bytes = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return pem_bytes.decode('utf-8') diff --git a/wavefront/server/packages/flo_cloud/flo_cloud/kms.py b/wavefront/server/packages/flo_cloud/flo_cloud/kms.py index 48e73136..c77eae9a 100644 --- a/wavefront/server/packages/flo_cloud/flo_cloud/kms.py +++ b/wavefront/server/packages/flo_cloud/flo_cloud/kms.py @@ -1,4 +1,5 @@ from .aws.kms import AwsKMS +from .azure.key_vault import AzureKMS from .gcp.kms import GcpKMS from ._types import CloudProvider, FloKMS @@ -13,6 +14,8 @@ def __get_kms_client(self) -> FloKMS: return AwsKMS() elif self.cloud_provider == CloudProvider.GCP.value: return GcpKMS() + elif self.cloud_provider == CloudProvider.AZURE.value: + return AzureKMS() else: raise ValueError(f'Unsupported cloud provider: {self.cloud_provider}') diff --git a/wavefront/server/packages/flo_cloud/pyproject.toml b/wavefront/server/packages/flo_cloud/pyproject.toml index 7692bb21..99bd6ea5 100644 --- a/wavefront/server/packages/flo_cloud/pyproject.toml +++ b/wavefront/server/packages/flo_cloud/pyproject.toml @@ -9,6 +9,7 @@ readme = "README.md" requires-python = ">=3.11" dependencies = [ "azure-identity>=1.17.0", + "azure-keyvault-keys>=4.9.0", "azure-storage-blob>=12.20.0", "azure-storage-queue>=12.10.0", "boto3<=1.38.40", diff --git a/wavefront/server/uv.lock b/wavefront/server/uv.lock index 067eaa25..f7bf7aa7 100644 --- a/wavefront/server/uv.lock +++ b/wavefront/server/uv.lock @@ -523,6 +523,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/75/54/81683b6756676a22e037b209695b08008258e603f7e47c56834029c5922a/azure_identity-1.25.0-py3-none-any.whl", hash = "sha256:becaec086bbdf8d1a6aa4fb080c2772a0f824a97d50c29637ec8cc4933f1e82d", size = 190861, upload-time = "2025-09-12T01:30:06.474Z" }, ] +[[package]] +name = "azure-keyvault-keys" +version = "4.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-core" }, + { name = "cryptography" }, + { name = "isodate" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/69/ed/450c9389d76be1a95a056528ec2b832a3721858dd47b1f4eb12dab7060a1/azure_keyvault_keys-4.11.0.tar.gz", hash = "sha256:f257b1917a2c3a88983e3f5675a6419449eb262318888d5b51e1cb3bed79779a", size = 241309, upload-time = "2025-06-16T22:52:04.296Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/ac/fa42e6b316712604a63bf7b3cb60d619d92890e038b87e1b4bba7437bc36/azure_keyvault_keys-4.11.0-py3-none-any.whl", hash = "sha256:fa5febd5805f0fed4c0a1d13c9096081c72a6fa36ccae1299a137f34280eda53", size = 191303, upload-time = "2025-06-16T22:52:06.1Z" }, +] + [[package]] name = "azure-storage-blob" version = "12.28.0" @@ -1419,6 +1434,7 @@ version = "0.1.0" source = { editable = "packages/flo_cloud" } dependencies = [ { name = "azure-identity" }, + { name = "azure-keyvault-keys" }, { name = "azure-storage-blob" }, { name = "azure-storage-queue" }, { name = "boto3" }, @@ -1433,6 +1449,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "azure-identity", specifier = ">=1.17.0" }, + { name = "azure-keyvault-keys", specifier = ">=4.9.0" }, { name = "azure-storage-blob", specifier = ">=12.20.0" }, { name = "azure-storage-queue", specifier = ">=12.10.0" }, { name = "boto3", specifier = "<=1.38.40" },