diff --git a/src/AzureClient/AzureClient.cs b/src/AzureClient/AzureClient.cs index 27ebd616f1..c03ad4a61f 100644 --- a/src/AzureClient/AzureClient.cs +++ b/src/AzureClient/AzureClient.cs @@ -5,20 +5,15 @@ using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Net; using System.Threading.Tasks; using Microsoft.Azure.Quantum; -using Microsoft.Azure.Quantum.Client; using Microsoft.Azure.Quantum.Client.Models; using Microsoft.Extensions.Logging; -using Microsoft.Identity.Client; -using Microsoft.Identity.Client.Extensions.Msal; using Microsoft.Jupyter.Core; using Microsoft.Quantum.IQSharp.Common; using Microsoft.Quantum.Simulation.Common; -using Microsoft.Rest.Azure; namespace Microsoft.Quantum.IQSharp.AzureClient { @@ -30,19 +25,15 @@ public class AzureClient : IAzureClient private IEntryPointGenerator EntryPointGenerator { get; } private string ConnectionString { get; set; } = string.Empty; private AzureExecutionTarget? ActiveTarget { get; set; } - private AuthenticationResult? AuthenticationResult { get; set; } - private IQuantumClient? QuantumClient { get; set; } - private Azure.Quantum.IWorkspace? ActiveWorkspace { get; set; } + private IAzureWorkspace? ActiveWorkspace { get; set; } private string MostRecentJobId { get; set; } = string.Empty; - private IPage? AvailableProviders { get; set; } - private IEnumerable? AvailableTargets { get => AvailableProviders?.SelectMany(provider => provider.Targets); } - private IEnumerable? ValidExecutionTargets { get => AvailableTargets?.Where(target => AzureExecutionTarget.IsValid(target.Id)); } - private string ValidExecutionTargetsDisplayText - { - get => ValidExecutionTargets == null - ? "(no execution targets available)" - : string.Join(", ", ValidExecutionTargets.Select(target => target.Id)); - } + private IEnumerable? AvailableProviders { get; set; } + private IEnumerable? AvailableTargets => AvailableProviders?.SelectMany(provider => provider.Targets); + private IEnumerable? ValidExecutionTargets => AvailableTargets?.Where(target => AzureExecutionTarget.IsValid(target.Id)); + private string ValidExecutionTargetsDisplayText => + ValidExecutionTargets == null + ? "(no execution targets available)" + : string.Join(", ", ValidExecutionTargets.Select(target => target.Id)); public AzureClient( IExecutionEngine engine, @@ -77,85 +68,20 @@ public async Task ConnectAsync(IChannel channel, { ConnectionString = storageAccountConnectionString; - var azureEnvironmentEnvVarName = "AZURE_QUANTUM_ENV"; - var azureEnvironmentName = System.Environment.GetEnvironmentVariable(azureEnvironmentEnvVarName); - var azureEnvironment = AzureEnvironment.Create(azureEnvironmentName, subscriptionId); - - var msalApp = PublicClientApplicationBuilder - .Create(azureEnvironment.ClientId) - .WithAuthority(azureEnvironment.Authority) - .Build(); - - // Register the token cache for serialization - var cacheFileName = "aad.bin"; - var cacheDirectoryEnvVarName = "AZURE_QUANTUM_TOKEN_CACHE"; - var cacheDirectory = System.Environment.GetEnvironmentVariable(cacheDirectoryEnvVarName); - if (string.IsNullOrEmpty(cacheDirectory)) - { - cacheDirectory = Path.Join(System.Environment.GetFolderPath(System.Environment.SpecialFolder.UserProfile), ".azure-quantum"); - } - - var storageCreationProperties = new StorageCreationPropertiesBuilder(cacheFileName, cacheDirectory, azureEnvironment.ClientId).Build(); - var cacheHelper = await MsalCacheHelper.CreateAsync(storageCreationProperties); - cacheHelper.RegisterCache(msalApp.UserTokenCache); - - bool shouldShowLoginPrompt = refreshCredentials; - if (!shouldShowLoginPrompt) - { - try - { - var accounts = await msalApp.GetAccountsAsync(); - AuthenticationResult = await msalApp.AcquireTokenSilent( - azureEnvironment.Scopes, accounts.FirstOrDefault()).WithAuthority(msalApp.Authority).ExecuteAsync(); - } - catch (MsalUiRequiredException) - { - shouldShowLoginPrompt = true; - } - } - - if (shouldShowLoginPrompt) - { - AuthenticationResult = await msalApp.AcquireTokenWithDeviceCode( - azureEnvironment.Scopes, - deviceCodeResult => - { - channel.Stdout(deviceCodeResult.Message); - return Task.FromResult(0); - }).WithAuthority(msalApp.Authority).ExecuteAsync(); - } - - if (AuthenticationResult == null) + var azureEnvironment = AzureEnvironment.Create(subscriptionId); + ActiveWorkspace = await azureEnvironment.GetAuthenticatedWorkspaceAsync(channel, resourceGroupName, workspaceName, refreshCredentials); + if (ActiveWorkspace == null) { return AzureClientError.AuthenticationFailed.ToExecutionResult(); } - var credentials = new Rest.TokenCredentials(AuthenticationResult.AccessToken); - QuantumClient = new QuantumClient(credentials) - { - SubscriptionId = subscriptionId, - ResourceGroupName = resourceGroupName, - WorkspaceName = workspaceName, - BaseUri = azureEnvironment.BaseUri, - }; - ActiveWorkspace = new Azure.Quantum.Workspace( - QuantumClient.SubscriptionId, - QuantumClient.ResourceGroupName, - QuantumClient.WorkspaceName, - AuthenticationResult?.AccessToken, - azureEnvironment.BaseUri); - - try - { - AvailableProviders = await QuantumClient.Providers.GetStatusAsync(); - } - catch (Exception e) + AvailableProviders = await ActiveWorkspace.GetProvidersAsync(); + if (AvailableProviders == null) { - Logger?.LogError(e, $"Failed to download providers list from Azure Quantum workspace: {e.Message}"); return AzureClientError.WorkspaceNotFound.ToExecutionResult(); } - channel.Stdout($"Connected to Azure Quantum workspace {QuantumClient.WorkspaceName}."); + channel.Stdout($"Connected to Azure Quantum workspace {ActiveWorkspace.Name}."); return ValidExecutionTargets.ToExecutionResult(); } @@ -163,12 +89,12 @@ public async Task ConnectAsync(IChannel channel, /// public async Task GetConnectionStatusAsync(IChannel channel) { - if (QuantumClient == null || AvailableProviders == null) + if (ActiveWorkspace == null || AvailableProviders == null) { return AzureClientError.NotConnected.ToExecutionResult(); } - channel.Stdout($"Connected to Azure Quantum workspace {QuantumClient.WorkspaceName}."); + channel.Stdout($"Connected to Azure Quantum workspace {ActiveWorkspace.Name}."); return ValidExecutionTargets.ToExecutionResult(); } @@ -194,7 +120,7 @@ private async Task SubmitOrExecuteJobAsync(IChannel channel, Az return AzureClientError.NoOperationName.ToExecutionResult(); } - var machine = QuantumMachineFactory.CreateMachine(ActiveWorkspace, ActiveTarget.TargetId, ConnectionString); + var machine = ActiveWorkspace.CreateQuantumMachine(ActiveTarget.TargetId, ConnectionString); if (machine == null) { // We should never get here, since ActiveTarget should have already been validated at the time it was set. @@ -258,7 +184,7 @@ private async Task SubmitOrExecuteJobAsync(IChannel channel, Az // handle Jupyter kernel interrupt here and break out of this loop await Task.Delay(TimeSpan.FromSeconds(submissionContext.ExecutionPollingInterval)); if (cts.IsCancellationRequested) break; - cloudJob = await GetCloudJob(MostRecentJobId); + cloudJob = await ActiveWorkspace.GetJobAsync(MostRecentJobId); channel.Stdout($"[{DateTime.Now.ToLongTimeString()}] Current job status: {cloudJob?.Status ?? "Unknown"}"); } while (cloudJob == null || cloudJob.InProgress); @@ -351,7 +277,7 @@ public async Task GetJobResultAsync(IChannel channel, string jo jobId = MostRecentJobId; } - var job = await GetCloudJob(jobId); + var job = await ActiveWorkspace.GetJobAsync(jobId); if (job == null) { channel.Stderr($"Job ID {jobId} not found in current Azure Quantum workspace."); @@ -398,7 +324,7 @@ public async Task GetJobStatusAsync(IChannel channel, string jo jobId = MostRecentJobId; } - var job = await GetCloudJob(jobId); + var job = await ActiveWorkspace.GetJobAsync(jobId); if (job == null) { channel.Stderr($"Job ID {jobId} not found in current Azure Quantum workspace."); @@ -417,7 +343,7 @@ public async Task GetJobListAsync(IChannel channel) return AzureClientError.NotConnected.ToExecutionResult(); } - var jobs = await GetCloudJobs(); + var jobs = await ActiveWorkspace.ListJobsAsync(); if (jobs == null || jobs.Count() == 0) { channel.Stderr("No jobs found in current Azure Quantum workspace."); @@ -426,33 +352,5 @@ public async Task GetJobListAsync(IChannel channel) return jobs.ToExecutionResult(); } - - private async Task GetCloudJob(string jobId) - { - try - { - return await ActiveWorkspace.GetJobAsync(jobId); - } - catch (Exception e) - { - Logger?.LogError(e, $"Failed to retrieve the specified Azure Quantum job: {e.Message}"); - } - - return null; - } - - private async Task?> GetCloudJobs() - { - try - { - return await ActiveWorkspace.ListJobsAsync(); - } - catch (Exception e) - { - Logger?.LogError(e, $"Failed to retrieve the list of jobs from the Azure Quantum workspace: {e.Message}"); - } - - return null; - } } } diff --git a/src/AzureClient/AzureEnvironment.cs b/src/AzureClient/AzureEnvironment.cs index 20db5920ae..729a56ef59 100644 --- a/src/AzureClient/AzureEnvironment.cs +++ b/src/AzureClient/AzureEnvironment.cs @@ -5,9 +5,14 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Net; -using System.Text; +using System.Threading.Tasks; +using Microsoft.Azure.Quantum.Client; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.Extensions.Msal; +using Microsoft.Jupyter.Core; namespace Microsoft.Quantum.IQSharp.AzureClient { @@ -15,25 +20,31 @@ internal enum AzureEnvironmentType { Production, Canary, Dogfood }; internal class AzureEnvironment { - public string ClientId { get; private set; } = string.Empty; - public string Authority { get; private set; } = string.Empty; - public List Scopes { get; private set; } = new List(); - public Uri? BaseUri { get; private set; } + public AzureEnvironmentType Type { get; private set; } + + private string SubscriptionId { get; set; } = string.Empty; + private string ClientId { get; set; } = string.Empty; + private string Authority { get; set; } = string.Empty; + private List Scopes { get; set; } = new List(); + private Uri? BaseUri { get; set; } private AzureEnvironment() { } - public static AzureEnvironment Create(string environment, string subscriptionId) + public static AzureEnvironment Create(string subscriptionId) { - if (Enum.TryParse(environment, true, out AzureEnvironmentType environmentType)) + var azureEnvironmentEnvVarName = "AZURE_QUANTUM_ENV"; + var azureEnvironmentName = System.Environment.GetEnvironmentVariable(azureEnvironmentEnvVarName); + + if (Enum.TryParse(azureEnvironmentName, true, out AzureEnvironmentType environmentType)) { switch (environmentType) { case AzureEnvironmentType.Production: - return Production(); + return Production(subscriptionId); case AzureEnvironmentType.Canary: - return Canary(); + return Canary(subscriptionId); case AzureEnvironmentType.Dogfood: return Dogfood(subscriptionId); default: @@ -41,30 +52,104 @@ public static AzureEnvironment Create(string environment, string subscriptionId) } } - return Production(); + return Production(subscriptionId); + } + + public async Task GetAuthenticatedWorkspaceAsync(IChannel channel, string resourceGroupName, string workspaceName, bool refreshCredentials) + { + // Find the token cache folder + var cacheDirectoryEnvVarName = "AZURE_QUANTUM_TOKEN_CACHE"; + var cacheDirectory = System.Environment.GetEnvironmentVariable(cacheDirectoryEnvVarName); + if (string.IsNullOrEmpty(cacheDirectory)) + { + cacheDirectory = Path.Join(System.Environment.GetFolderPath(System.Environment.SpecialFolder.UserProfile), ".azure-quantum"); + } + + // Register the token cache for serialization + var cacheFileName = "aad.bin"; + var storageCreationProperties = new StorageCreationPropertiesBuilder(cacheFileName, cacheDirectory, ClientId).Build(); + var cacheHelper = await MsalCacheHelper.CreateAsync(storageCreationProperties); + var msalApp = PublicClientApplicationBuilder.Create(ClientId).WithAuthority(Authority).Build(); + cacheHelper.RegisterCache(msalApp.UserTokenCache); + + // Perform the authentication + bool shouldShowLoginPrompt = refreshCredentials; + AuthenticationResult? authenticationResult = null; + if (!shouldShowLoginPrompt) + { + try + { + var accounts = await msalApp.GetAccountsAsync(); + authenticationResult = await msalApp.AcquireTokenSilent( + Scopes, accounts.FirstOrDefault()).WithAuthority(msalApp.Authority).ExecuteAsync(); + } + catch (MsalUiRequiredException) + { + shouldShowLoginPrompt = true; + } + } + + if (shouldShowLoginPrompt) + { + authenticationResult = await msalApp.AcquireTokenWithDeviceCode( + Scopes, + deviceCodeResult => + { + channel.Stdout(deviceCodeResult.Message); + return Task.FromResult(0); + }).WithAuthority(msalApp.Authority).ExecuteAsync(); + } + + if (authenticationResult == null) + { + return null; + } + + // Construct and return the AzureWorkspace object + var credentials = new Rest.TokenCredentials(authenticationResult.AccessToken); + var azureQuantumClient = new QuantumClient(credentials) + { + SubscriptionId = SubscriptionId, + ResourceGroupName = resourceGroupName, + WorkspaceName = workspaceName, + BaseUri = BaseUri, + }; + var azureQuantumWorkspace = new Azure.Quantum.Workspace( + azureQuantumClient.SubscriptionId, + azureQuantumClient.ResourceGroupName, + azureQuantumClient.WorkspaceName, + authenticationResult?.AccessToken, + BaseUri); + + return new AzureWorkspace(azureQuantumClient, azureQuantumWorkspace); } - private static AzureEnvironment Production() => + private static AzureEnvironment Production(string subscriptionId) => new AzureEnvironment() { + Type = AzureEnvironmentType.Production, ClientId = "84ba0947-6c53-4dd2-9ca9-b3694761521b", // QDK client ID Authority = "https://login.microsoftonline.com/common", Scopes = new List() { "https://quantum.microsoft.com/Jobs.ReadWrite" }, BaseUri = new Uri("https://app-jobscheduler-prod.azurewebsites.net/"), + SubscriptionId = subscriptionId, }; private static AzureEnvironment Dogfood(string subscriptionId) => new AzureEnvironment() { + Type = AzureEnvironmentType.Dogfood, ClientId = "46a998aa-43d0-4281-9cbb-5709a507ac36", // QDK dogfood client ID Authority = GetDogfoodAuthority(subscriptionId), Scopes = new List() { "api://dogfood.azure-quantum/Jobs.ReadWrite" }, BaseUri = new Uri("https://app-jobscheduler-test.azurewebsites.net/"), + SubscriptionId = subscriptionId, }; - private static AzureEnvironment Canary() + private static AzureEnvironment Canary(string subscriptionId) { - var canary = Production(); + var canary = Production(subscriptionId); + canary.Type = AzureEnvironmentType.Canary; canary.BaseUri = new Uri("https://app-jobs-canarysouthcentralus.azurewebsites.net/"); return canary; } diff --git a/src/AzureClient/AzureWorkspace.cs b/src/AzureClient/AzureWorkspace.cs new file mode 100644 index 0000000000..54fb782abc --- /dev/null +++ b/src/AzureClient/AzureWorkspace.cs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.Azure.Quantum; +using Microsoft.Azure.Quantum.Client; +using Microsoft.Azure.Quantum.Client.Models; +using Microsoft.Extensions.Logging; +using Microsoft.Quantum.Runtime; + +namespace Microsoft.Quantum.IQSharp.AzureClient +{ + internal class AzureWorkspace : IAzureWorkspace + { + public string? Name => AzureQuantumClient?.WorkspaceName; + + private Azure.Quantum.IWorkspace AzureQuantumWorkspace { get; set; } + private QuantumClient AzureQuantumClient { get; set; } + private ILogger Logger { get; } = new LoggerFactory().CreateLogger(); + + public AzureWorkspace(QuantumClient azureQuantumClient, Azure.Quantum.Workspace azureQuantumWorkspace) + { + AzureQuantumClient = azureQuantumClient; + AzureQuantumWorkspace = azureQuantumWorkspace; + } + + public async Task?> GetProvidersAsync() + { + try + { + return await AzureQuantumClient.Providers.GetStatusAsync(); + } + catch (Exception e) + { + Logger.LogError(e, $"Failed to retrieve the providers list from the Azure Quantum workspace: {e.Message}"); + } + + return null; + } + + public async Task GetJobAsync(string jobId) + { + try + { + return await AzureQuantumWorkspace.GetJobAsync(jobId); + } + catch (Exception e) + { + Logger.LogError(e, $"Failed to retrieve the specified Azure Quantum job: {e.Message}"); + } + + return null; + } + + public async Task?> ListJobsAsync() + { + try + { + return await AzureQuantumWorkspace.ListJobsAsync(); + } + catch (Exception e) + { + Logger.LogError(e, $"Failed to retrieve the list of jobs from the Azure Quantum workspace: {e.Message}"); + } + + return null; + } + + public IQuantumMachine? CreateQuantumMachine(string targetId, string storageAccountConnectionString) + { + return QuantumMachineFactory.CreateMachine(AzureQuantumWorkspace, targetId, storageAccountConnectionString); + } + } +} diff --git a/src/AzureClient/IAzureWorkspace.cs b/src/AzureClient/IAzureWorkspace.cs new file mode 100644 index 0000000000..a08152b3b7 --- /dev/null +++ b/src/AzureClient/IAzureWorkspace.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#nullable enable + +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.Azure.Quantum; +using Microsoft.Azure.Quantum.Client.Models; +using Microsoft.Quantum.Runtime; + +namespace Microsoft.Quantum.IQSharp.AzureClient +{ + internal interface IAzureWorkspace + { + public string Name { get; } + + public Task> GetProvidersAsync(); + public Task GetJobAsync(string jobId); + public Task> ListJobsAsync(); + public IQuantumMachine? CreateQuantumMachine(string targetId, string storageAccountConnectionString); + } +} \ No newline at end of file