Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 140 additions & 52 deletions cmd/server/internal/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const (
TenantNotFound = "tenant not found"
)

const SubscriptionDomain = "subscription domain"
const ErrorOccurred = "Error occurred "
const InvalidRequestMethod = "invalid request method"
const AuthorizationCheckFailed = "authorization check failed"
Expand All @@ -81,6 +82,8 @@ const (
type RequestInfo struct {
// One of "SMS" or "SaaS"
subscriptionType subscriptionType
// subscription domain from the subscription request, used for constructing the subscription URL. If not present, existing fallbacks will be used
subscriptionDomain string
// payload Details
payload *payloadDetails
// header details
Expand Down Expand Up @@ -129,12 +132,12 @@ type callbackResponse struct {

type SaaSCallbackResponse struct {
callbackResponse `json:",inline"`
SubscriptionUrl string `json:"subscriptionUrl"`
SubscriptionUrl string `json:"subscriptionUrl,omitempty"`
}

type SmsCallbackResponse struct {
callbackResponse `json:",inline"`
ApplicationUrl string `json:"applicationUrl"`
ApplicationUrl string `json:"applicationUrl,omitempty"`
}

type CallbackReqInfo struct {
Expand All @@ -150,8 +153,9 @@ type OAuthResponse struct {
}

type tenantInfo struct {
tenantId string
tenantSubDomain string
tenantId string
tenantSubDomain string
subscriptionDomain string
}

func (s *SubscriptionHandler) CreateTenant(reqInfo *RequestInfo) *Result {
Expand All @@ -173,6 +177,12 @@ func (s *SubscriptionHandler) CreateTenant(reqInfo *RequestInfo) *Result {
return &Result{Tenant: nil, Message: err.Error()}
}

appUrl, err := s.getAppURL(reqInfo.subscriptionDomain, reqInfo.payload.subdomain, ca)
if err != nil {
util.LogError(err, ErrorOccurred, TenantProvisioning, ca, nil)
return &Result{Tenant: nil, Message: "Error constructing subscription URL: " + err.Error()}
}

// Check if A CRO for CAPTenant already exists
tenant := s.getTenantByBtpAppIdentifier(ca.Spec.GlobalAccountId, reqInfo.payload.appName, reqInfo.payload.tenantId, ca.Namespace, TenantProvisioning).Tenant

Expand All @@ -193,28 +203,28 @@ func (s *SubscriptionHandler) CreateTenant(reqInfo *RequestInfo) *Result {

// TODO: consider retrying tenant creation if it is in Error state
if tenant != nil {
tenantIn := tenantInfo{tenantId: reqInfo.payload.tenantId, tenantSubDomain: reqInfo.payload.subdomain}
tenantIn := tenantInfo{tenantId: reqInfo.payload.tenantId, tenantSubDomain: reqInfo.payload.subdomain, subscriptionDomain: reqInfo.subscriptionDomain}
callbackReqInfo := s.getCallbackReqInfo(reqInfo.subscriptionType, reqInfo.headerDetails.callbackInfo, saasData, smsData)
s.initializeCallback(tenant.Name, ca, callbackReqInfo, tenantIn, true)
}

// Tenant created/exists
message := func(isCreated, isUpdated bool) string {
if isCreated {
return ResourceCreated
} else if isUpdated {
return ResourceUpdated
} else {
return ResourceFound
}
s.initializeCallback(appUrl, tenant.Name, ca, callbackReqInfo, tenantIn, true)
}

if created {
util.LogInfo("Tenant successfully created", TenantProvisioning, ca, tenant, "message", message(created, updated))
util.LogInfo("Tenant successfully created", TenantProvisioning, ca, tenant, "message", getMessage(created, updated))
} else if updated {
util.LogInfo("Tenant successfully updated", TenantProvisioning, ca, tenant, "message", message(created, updated))
util.LogInfo("Tenant successfully updated", TenantProvisioning, ca, tenant, "message", getMessage(created, updated))
}
return &Result{Tenant: tenant, Message: getMessage(created, updated)}
}

func getMessage(isCreated, isUpdated bool) string {
// Tenant created/exists
if isCreated {
return ResourceCreated
} else if isUpdated {
return ResourceUpdated
} else {
return ResourceFound
}
return &Result{Tenant: tenant, Message: message(created, updated)}
}

func (s *SubscriptionHandler) createTenant(reqInfo *RequestInfo, ca *v1alpha1.CAPApplication) (tenant *v1alpha1.CAPTenant, err error) {
Expand Down Expand Up @@ -493,9 +503,9 @@ func (s *SubscriptionHandler) DeleteTenant(reqInfo *RequestInfo) *Result {
return &Result{Tenant: nil, Message: err.Error()}
}

tenantIn := tenantInfo{tenantId: reqInfo.payload.tenantId, tenantSubDomain: reqInfo.payload.subdomain}
tenantIn := tenantInfo{tenantId: reqInfo.payload.tenantId, tenantSubDomain: reqInfo.payload.subdomain, subscriptionDomain: reqInfo.subscriptionDomain}
callbackReqInfo := s.getCallbackReqInfo(reqInfo.subscriptionType, reqInfo.headerDetails.callbackInfo, saasData, smsData)
s.initializeCallback(tenant.Name, ca, callbackReqInfo, tenantIn, false)
s.initializeCallback("", tenant.Name, ca, callbackReqInfo, tenantIn, false)

return &Result{Tenant: tenant, Message: ResourceDeleted}
}
Expand Down Expand Up @@ -591,20 +601,12 @@ func (s *SubscriptionHandler) checkCertIssuerAndSubject(xForwardedClientCert str
return nil
}

func (s *SubscriptionHandler) initializeCallback(tenantName string, ca *v1alpha1.CAPApplication, callbackReqInfo *CallbackReqInfo, tenantIn tenantInfo, isProvisioning bool) {
subscriptionDomain := ca.Annotations[AnnotationSubscriptionDomain]
if subscriptionDomain == "" {
subscriptionDomain = s.getPrimaryDomain(ca)
}

appUrl := "https://" + tenantIn.tenantSubDomain + "." + subscriptionDomain
asyncCallbackPath := callbackReqInfo.CallbackPath
util.LogInfo("Callback initialized", TenantProvisioning, ca, nil, "subscription URL", appUrl, "async callback path", asyncCallbackPath, "tenantName", tenantName)

func (s *SubscriptionHandler) initializeCallback(appUrl, tenantName string, ca *v1alpha1.CAPApplication, callbackReqInfo *CallbackReqInfo, tenantIn tenantInfo, isProvisioning bool) {
step := TenantProvisioning
if !isProvisioning {
step = TenantDeprovisioning
}
util.LogInfo("Callback initialized", step, ca, nil, "subscription URL", appUrl, "async callback path", callbackReqInfo.CallbackPath, "tenantName", tenantName)

go func() {
// create a context for tenant checks and outgoing requests
Expand Down Expand Up @@ -635,10 +637,93 @@ func (s *SubscriptionHandler) initializeCallback(tenantName string, ca *v1alpha1
} else {
additionalOutput = nil
}
s.handleAsyncCallback(ctx, callbackReqInfo, status, asyncCallbackPath, appUrl, additionalOutput, isProvisioning)

callbackResponse := getCallbackResponseStatus(status, isProvisioning, additionalOutput)

s.handleAsyncCallback(ctx, callbackReqInfo, callbackReqInfo.CallbackPath, appUrl, callbackResponse)
}()
}

func getCallbackResponseStatus(status bool, isProvisioning bool, additionalOutput *map[string]any) *callbackResponse {
var responseStatus string
var message string
if status {
responseStatus = CallbackSucceeded
if isProvisioning {
message = ProvisioningSucceededMessage
} else {
message = DeprovisioningSucceededMessage
}
} else {
responseStatus = CallbackFailed
if isProvisioning {
message = ProvisioningFailedMessage
} else {
message = DeprovisioningFailedMessage
}
}

return &callbackResponse{
Status: responseStatus,
Message: message,
AdditionalOutput: additionalOutput,
}
}

func (s *SubscriptionHandler) getAppURL(payloadSubscriptionDomain, tenantSubdomain string, ca *v1alpha1.CAPApplication) (string, error) {
needsValidaton := true
var subscriptionDomain string
// Check if subscription domain is provided in the request payload.
if payloadSubscriptionDomain != "" {
subscriptionDomain = payloadSubscriptionDomain
util.LogInfo("Using subscription domain from request payload", TenantProvisioning, ca, nil, SubscriptionDomain, subscriptionDomain)
} else {
// Fallback:
// First, check if subscription domain is provided in the CAPApplication annotation. If not, fallback to calculating the primary domain from the CAPApplication domain refs and use that as the subscription domain.
subscriptionDomain = ca.Annotations[AnnotationSubscriptionDomain]
if subscriptionDomain == "" {
subscriptionDomain = s.getPrimaryDomain(ca)
needsValidaton = false
util.LogInfo("Using subscription domain from fallback 'primary' calculation", TenantProvisioning, ca, nil, SubscriptionDomain, subscriptionDomain)
} else {
util.LogInfo("Using subscription domain from CAPApplication annotation", TenantProvisioning, ca, nil, SubscriptionDomain, subscriptionDomain)
}
}

if needsValidaton {
err := s.validateDomain(subscriptionDomain, ca.Namespace)
if err != nil {
return "", err
}
}

return "https://" + tenantSubdomain + "." + subscriptionDomain, nil
}

func (s *SubscriptionHandler) validateDomain(domain, namespace string) error {
// First check for Domains in the apps namespace
domainsList, err := s.Clientset.SmeV1alpha1().Domains(namespace).List(context.TODO(), metav1.ListOptions{})
if err != nil {
return err
}
for _, d := range domainsList.Items {
if d.Spec.Domain == domain {
return nil
}
}

// Check for ClusterDomains if not found in the namespace
clusterDomainsList, err := s.Clientset.SmeV1alpha1().ClusterDomains(metav1.NamespaceAll).List(context.TODO(), metav1.ListOptions{})
if err != nil {
return err
}
for _, cd := range clusterDomainsList.Items {
if cd.Spec.Domain == domain {
return nil
}
}

util.LogInfo("Waiting for async callback after checks...", step, ca, nil, "tenantName", tenantName)
return fmt.Errorf("domain %s not found in Domains or ClusterDomains", domain)
}

func (s *SubscriptionHandler) getPrimaryDomain(ca *v1alpha1.CAPApplication) string {
Expand Down Expand Up @@ -859,7 +944,7 @@ func prepareTokenRequest(ctx context.Context, callbackReqInfo *CallbackReqInfo,
return tokenReq, nil
}

func (s *SubscriptionHandler) handleAsyncCallback(ctx context.Context, callbackReqInfo *CallbackReqInfo, status bool, asyncCallbackPath string, appUrl string, additionalOutput *map[string]any, isProvisioning bool) {
func (s *SubscriptionHandler) handleAsyncCallback(ctx context.Context, callbackReqInfo *CallbackReqInfo, asyncCallbackPath, appUrl string, callbackResponse *callbackResponse) {
// Get OAuth token
tokenClient := s.httpClientGenerator.NewHTTPClient()
tokenReq, err := prepareTokenRequest(ctx, callbackReqInfo, tokenClient)
Expand All @@ -886,19 +971,8 @@ func (s *SubscriptionHandler) handleAsyncCallback(ctx context.Context, callbackR
}
defer tokenResponse.Body.Close()

checkMatch := func(match bool, trueVal string, falseVal string) string {
if match {
return trueVal
}
return falseVal
}

var payload []byte
callbackResponse := &callbackResponse{
Status: checkMatch(status, CallbackSucceeded, CallbackFailed),
Message: checkMatch(status, checkMatch(isProvisioning, ProvisioningSucceededMessage, DeprovisioningSucceededMessage), checkMatch(isProvisioning, ProvisioningFailedMessage, DeprovisioningFailedMessage)),
AdditionalOutput: additionalOutput,
}

switch callbackReqInfo.SubscriptionType {
case SMS:
payload, _ = json.Marshal(&SmsCallbackResponse{
Expand Down Expand Up @@ -983,8 +1057,8 @@ func (s *SubscriptionHandler) HandleSMSRequest(w http.ResponseWriter, req *http.
}

func ProcessRequest(req *http.Request, subscriptionType subscriptionType) (*RequestInfo, error) {
var subscriptionGUID, tenantId, subdomain, globalAccountId, providerSubaccountId, appName string
var jsonPayload map[string]any
var subscriptionGUID, tenantId, subdomain, globalAccountId, providerSubaccountId, appName, subscriptionDomain string
jsonPayload := map[string]any{}

if !(req.Method == http.MethodDelete && subscriptionType == SMS) {
decoder := json.NewDecoder(req.Body)
Expand All @@ -1010,6 +1084,7 @@ func ProcessRequest(req *http.Request, subscriptionType subscriptionType) (*Requ
rootApp := jsonPayload["rootApplication"].(map[string]any)
providerSubaccountId = rootApp["providerSubaccountId"].(string)
appName = rootApp["appName"].(string)
subscriptionDomain = getSubscriptionDomain(rootApp)
case http.MethodDelete:
// get paramater from URL
subscriptionGUID = req.URL.Query().Get("subscriptionGUID")
Expand All @@ -1032,6 +1107,7 @@ func ProcessRequest(req *http.Request, subscriptionType subscriptionType) (*Requ
globalAccountId = jsonPayload["globalAccountGUID"].(string)
providerSubaccountId = jsonPayload["providerSubaccountId"].(string)
appName = jsonPayload["subscriptionAppName"].(string)
subscriptionDomain = getSubscriptionDomain(jsonPayload)
}

payload := &payloadDetails{
Expand All @@ -1045,12 +1121,24 @@ func ProcessRequest(req *http.Request, subscriptionType subscriptionType) (*Requ
raw: &jsonPayload,
}
return &RequestInfo{
subscriptionType: subscriptionType,
payload: payload,
headerDetails: &headerDetails,
subscriptionType: subscriptionType,
subscriptionDomain: subscriptionDomain,
payload: payload,
headerDetails: &headerDetails,
}, nil
}

func getSubscriptionDomain(payload map[string]any) string {
if subscriptionParams, ok := payload["subscriptionParams"]; ok {
if subscriptionParamsMap, ok := subscriptionParams.(map[string]any); ok {
if subscriptionDomain, ok := subscriptionParamsMap["subscriptionDomain"]; ok {
return subscriptionDomain.(string)
}
}
}
return ""
}

func NewSubscriptionHandler(clientset versioned.Interface, kubeClienset kubernetes.Interface) *SubscriptionHandler {
return &SubscriptionHandler{Clientset: clientset, KubeClienset: kubeClienset, httpClientGenerator: &httpClientGeneratorImpl{}}
}
Expand Down
10 changes: 3 additions & 7 deletions cmd/server/internal/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ func Test_provisioning(t *testing.T) {
{
name: "Provisioning Request valid (invalid clusterdomains)",
method: http.MethodPut,
body: `{"subscriptionAppName":"` + appName + `","globalAccountGUID":"` + globalAccountId + `","providerSubaccountId":"` + providerSubaccountId + `","subscriptionGUID":"` + subscriptionGUID + `","subscribedTenantId":"` + tenantId + `","subscribedSubdomain":"` + subDomain + `"}`,
body: `{"subscriptionAppName":"` + appName + `","globalAccountGUID":"` + globalAccountId + `","providerSubaccountId":"` + providerSubaccountId + `","subscriptionGUID":"` + subscriptionGUID + `","subscribedTenantId":"` + tenantId + `","subscribedSubdomain":"` + subDomain + `","subscriptionParams":""}`,
createCROs: true,
invalidClusterDomain: true,
expectedStatusCode: http.StatusAccepted,
Expand Down Expand Up @@ -1219,11 +1219,9 @@ func TestAsyncCallback(t *testing.T) {
subHandler.handleAsyncCallback(
ctx,
callbackReqInfo,
p.status,
"/async/callback",
"https://app.cluster.local",
p.additionalData,
p.isProvisioning,
getCallbackResponseStatus(p.status, p.isProvisioning, p.additionalData),
)
})
}
Expand All @@ -1245,11 +1243,9 @@ func TestAsyncCallback(t *testing.T) {
subHandler.handleAsyncCallback(
ctx,
callbackReqInfo,
p.status,
"/async/callback",
"https://app.cluster.local",
p.additionalData,
p.isProvisioning,
getCallbackResponseStatus(p.status, p.isProvisioning, p.additionalData),
)
})
}
Expand Down
Loading