Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/scripts/cla_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def main():

# Allowlist for bots
bot_patterns = [
"autofix-ci",
"dependabot",
"github-actions",
"renovate",
Expand Down
51 changes: 33 additions & 18 deletions tests/test_licensing.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
"""Tests for cortex/licensing.py - License management and feature gating."""

import json
import pytest
from datetime import datetime, timezone, timedelta
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch

import pytest

from cortex.licensing import (
FEATURE_NAMES,
FEATURE_REQUIREMENTS,
LICENSE_FILE,
FeatureNotAvailableError,
FeatureTier,
LicenseInfo,
FEATURE_REQUIREMENTS,
FEATURE_NAMES,
_get_hostname,
activate_license,
check_feature,
get_license_info,
get_license_tier,
check_feature,
require_feature,
activate_license,
show_license_status,
show_upgrade_prompt,
FeatureNotAvailableError,
LICENSE_FILE,
_get_hostname,
)


Expand Down Expand Up @@ -143,19 +144,20 @@ class TestGetLicenseInfo:
def reset_cache(self):
"""Reset license cache before each test."""
import cortex.licensing as lic

lic._cached_license = None
yield
lic._cached_license = None

def test_returns_license_info(self):
"""Should return LicenseInfo object."""
with patch.object(Path, 'exists', return_value=False):
with patch.object(Path, "exists", return_value=False):
info = get_license_info()
assert isinstance(info, LicenseInfo)

def test_default_community_tier(self):
"""Should default to community tier when no license file."""
with patch.object(Path, 'exists', return_value=False):
with patch.object(Path, "exists", return_value=False):
info = get_license_info()
assert info.tier == FeatureTier.COMMUNITY

Expand All @@ -174,14 +176,14 @@ def test_reads_license_file(self, tmp_path):
license_file = tmp_path / "license.key"
license_file.write_text(json.dumps(license_data))

with patch.object(lic, 'LICENSE_FILE', license_file):
with patch.object(lic, "LICENSE_FILE", license_file):
info = get_license_info()
assert info.tier == "pro"
assert info.organization == "Test Org"

def test_caches_result(self):
"""Should cache license info."""
with patch.object(Path, 'exists', return_value=False):
with patch.object(Path, "exists", return_value=False):
info1 = get_license_info()
info2 = get_license_info()
assert info1 is info2
Expand All @@ -194,13 +196,15 @@ class TestCheckFeature:
def reset_cache(self):
"""Reset license cache before each test."""
import cortex.licensing as lic

lic._cached_license = None
yield
lic._cached_license = None

def test_community_features_allowed(self):
"""Community tier should access community features."""
import cortex.licensing as lic

lic._cached_license = LicenseInfo(tier=FeatureTier.COMMUNITY)

# Unknown features default to community
Expand All @@ -209,20 +213,23 @@ def test_community_features_allowed(self):
def test_pro_feature_blocked_for_community(self):
"""Community tier should not access pro features."""
import cortex.licensing as lic

lic._cached_license = LicenseInfo(tier=FeatureTier.COMMUNITY)

assert check_feature("cloud_llm", silent=True) is False

def test_pro_feature_allowed_for_pro(self):
"""Pro tier should access pro features."""
import cortex.licensing as lic

lic._cached_license = LicenseInfo(tier=FeatureTier.PRO)

assert check_feature("cloud_llm", silent=True) is True

def test_enterprise_feature_allowed_for_enterprise(self):
"""Enterprise tier should access all features."""
import cortex.licensing as lic

lic._cached_license = LicenseInfo(tier=FeatureTier.ENTERPRISE)

assert check_feature("sso", silent=True) is True
Expand All @@ -231,6 +238,7 @@ def test_enterprise_feature_allowed_for_enterprise(self):
def test_shows_upgrade_prompt(self, capsys):
"""Should show upgrade prompt when feature blocked."""
import cortex.licensing as lic

lic._cached_license = LicenseInfo(tier=FeatureTier.COMMUNITY)

check_feature("cloud_llm", silent=False)
Expand All @@ -245,13 +253,15 @@ class TestRequireFeatureDecorator:
def reset_cache(self):
"""Reset license cache before each test."""
import cortex.licensing as lic

lic._cached_license = None
yield
lic._cached_license = None

def test_allows_when_feature_available(self):
"""Should allow function call when feature available."""
import cortex.licensing as lic

lic._cached_license = LicenseInfo(tier=FeatureTier.PRO)

@require_feature("cloud_llm")
Expand All @@ -263,6 +273,7 @@ def test_func():
def test_raises_when_feature_blocked(self):
"""Should raise FeatureNotAvailableError when feature blocked."""
import cortex.licensing as lic

lic._cached_license = LicenseInfo(tier=FeatureTier.COMMUNITY)

@require_feature("cloud_llm")
Expand Down Expand Up @@ -297,6 +308,7 @@ class TestActivateLicense:
def reset_cache(self):
"""Reset license cache before each test."""
import cortex.licensing as lic

lic._cached_license = None
yield
lic._cached_license = None
Expand All @@ -314,8 +326,8 @@ def test_successful_activation(self, tmp_path):
"organization": "Test Org",
}

with patch.object(lic, 'LICENSE_FILE', license_file):
with patch('httpx.post', return_value=mock_response):
with patch.object(lic, "LICENSE_FILE", license_file):
with patch("httpx.post", return_value=mock_response):
result = activate_license("test-key-123")

assert result is True
Expand All @@ -329,7 +341,7 @@ def test_failed_activation(self):
"error": "Invalid key",
}

with patch('httpx.post', return_value=mock_response):
with patch("httpx.post", return_value=mock_response):
result = activate_license("invalid-key")

assert result is False
Expand All @@ -338,7 +350,7 @@ def test_network_error(self):
"""Should handle network errors gracefully."""
import httpx

with patch('httpx.post', side_effect=httpx.HTTPError("Network error")):
with patch("httpx.post", side_effect=httpx.HTTPError("Network error")):
result = activate_license("test-key")

assert result is False
Expand All @@ -351,13 +363,15 @@ class TestShowLicenseStatus:
def reset_cache(self):
"""Reset license cache before each test."""
import cortex.licensing as lic

lic._cached_license = None
yield
lic._cached_license = None

def test_shows_community_status(self, capsys):
"""Should show community tier status."""
import cortex.licensing as lic

lic._cached_license = LicenseInfo(tier=FeatureTier.COMMUNITY)

show_license_status()
Expand All @@ -369,6 +383,7 @@ def test_shows_community_status(self, capsys):
def test_shows_pro_status(self, capsys):
"""Should show pro tier status."""
import cortex.licensing as lic

lic._cached_license = LicenseInfo(
tier=FeatureTier.PRO,
organization="Test Corp",
Expand Down
Loading