From 66a1ca321d6681b3fd5bd4402d490b2e1a0a9223 Mon Sep 17 00:00:00 2001 From: Senthamil Sindhu Date: Thu, 21 Nov 2024 16:01:49 -0800 Subject: [PATCH] Add changes for ARM/AAD Token renewal --- src/Common/CosmosClient.ts | 10 +++- src/UserContext.ts | 1 + src/hooks/useAADAuth.ts | 3 +- src/hooks/useDatabaseAccounts.tsx | 89 +++++++++++++++++++++++++++---- src/hooks/useSubscriptions.tsx | 16 ++++-- 5 files changed, 103 insertions(+), 16 deletions(-) diff --git a/src/Common/CosmosClient.ts b/src/Common/CosmosClient.ts index f7a4fbfc5..64640dacc 100644 --- a/src/Common/CosmosClient.ts +++ b/src/Common/CosmosClient.ts @@ -8,11 +8,12 @@ import { AuthType } from "../AuthType"; import { BackendApi, PriorityLevel } from "../Common/Constants"; import * as Logger from "../Common/Logger"; import { Platform, configContext } from "../ConfigContext"; -import { userContext } from "../UserContext"; +import { updateUserContext, userContext } from "../UserContext"; import { logConsoleError } from "../Utils/NotificationConsoleUtils"; import * as PriorityBasedExecutionUtils from "../Utils/PriorityBasedExecutionUtils"; import { EmulatorMasterKey, HttpHeaders } from "./Constants"; import { getErrorMessage } from "./ErrorHandlingUtils"; +import { runCommand } from "hooks/useDatabaseAccounts"; const _global = typeof self === "undefined" ? window : self; @@ -32,6 +33,7 @@ export const tokenProvider = async (requestInfo: Cosmos.RequestInfo) => { return null; } const AUTH_PREFIX = `type=aad&ver=1.0&sig=`; + console.log("AAD Token ", userContext.aadToken); const authorizationToken = `${AUTH_PREFIX}${userContext.aadToken}`; return authorizationToken; } @@ -209,10 +211,14 @@ export function client(): Cosmos.CosmosClient { _defaultHeaders["x-ms-cosmos-priority-level"] = PriorityLevel.Default; } + const wrappedTokenProvider = async (requestInfo: Cosmos.RequestInfo) => { + return await runCommand(tokenProvider, requestInfo); + }; + const options: Cosmos.CosmosClientOptions = { endpoint: endpoint() || "https://cosmos.azure.com", // CosmosClient gets upset if we pass a bad URL. This should never actually get called key: userContext.dataPlaneRbacEnabled ? "" : userContext.masterKey, - tokenProvider, + tokenProvider: wrappedTokenProvider, userAgentSuffix: "Azure Portal", defaultHeaders: _defaultHeaders, connectionPolicy: { diff --git a/src/UserContext.ts b/src/UserContext.ts index 955452d3a..e9f59fbd1 100644 --- a/src/UserContext.ts +++ b/src/UserContext.ts @@ -81,6 +81,7 @@ export interface UserContext { readonly endpoint?: string; readonly aadToken?: string; readonly accessToken?: string; + readonly armToken?: string; readonly authorizationToken?: string; readonly resourceToken?: string; readonly subscriptionType?: SubscriptionType; diff --git a/src/hooks/useAADAuth.ts b/src/hooks/useAADAuth.ts index c20f953f7..dadc6d5db 100644 --- a/src/hooks/useAADAuth.ts +++ b/src/hooks/useAADAuth.ts @@ -3,6 +3,7 @@ import { useBoolean } from "@fluentui/react-hooks"; import * as React from "react"; import { configContext } from "../ConfigContext"; import { acquireTokenWithMsal, getMsalInstance } from "../Utils/AuthorizationUtils"; +import { updateUserContext } from "UserContext"; const msalInstance = await getMsalInstance(); @@ -79,7 +80,7 @@ export function useAADAuth(): ReturnType { authority: `${configContext.AAD_ENDPOINT}${tenantId}`, scopes: [`${configContext.ARM_ENDPOINT}/.default`], }); - + updateUserContext({ armToken: armToken }); setArmToken(armToken); setAuthFailure(null); } catch (error) { diff --git a/src/hooks/useDatabaseAccounts.tsx b/src/hooks/useDatabaseAccounts.tsx index f517b2e30..040b336b5 100644 --- a/src/hooks/useDatabaseAccounts.tsx +++ b/src/hooks/useDatabaseAccounts.tsx @@ -1,6 +1,8 @@ import { HttpHeaders } from "Common/Constants"; import { QueryRequestOptions, QueryResponse } from "Contracts/AzureResourceGraph"; import useSWR from "swr"; +import { updateUserContext, userContext } from "UserContext"; +import { acquireTokenWithMsal, getMsalInstance } from "Utils/AuthorizationUtils"; import { configContext } from "../ConfigContext"; import { DatabaseAccount } from "../Contracts/DataModels"; /* eslint-disable @typescript-eslint/no-explicit-any */ @@ -33,12 +35,9 @@ export async function fetchDatabaseAccounts(subscriptionId: string, accessToken: return accounts.sort((a, b) => a.name.localeCompare(b.name)); } -export async function fetchDatabaseAccountsFromGraph( - subscriptionId: string, - accessToken: string, -): Promise { +export async function fetchDatabaseAccountsFromGraph(subscriptionId: string): Promise { const headers = new Headers(); - const bearer = `Bearer ${accessToken}`; + const bearer = `Bearer ${userContext.armToken}`; headers.append("Authorization", bearer); headers.append(HttpHeaders.contentType, "application/json"); @@ -46,8 +45,9 @@ export async function fetchDatabaseAccountsFromGraph( const apiVersion = "2021-03-01"; const managementResourceGraphAPIURL = `${configContext.ARM_ENDPOINT}providers/Microsoft.ResourceGraph/resources?api-version=${apiVersion}`; - const databaseAccounts: DatabaseAccount[] = []; + let databaseAccounts: DatabaseAccount[] = []; let skipToken: string; + console.log("Old ARM Token - fetchDatabaseAccountsFromGraph function", userContext.armToken); do { const body = { query: databaseAccountsQuery, @@ -85,10 +85,81 @@ export async function fetchDatabaseAccountsFromGraph( return databaseAccounts.sort((a, b) => a.name.localeCompare(b.name)); } -export function useDatabaseAccounts(subscriptionId: string, armToken: string): DatabaseAccount[] | undefined { +export function useDatabaseAccounts(subscriptionId: string): DatabaseAccount[] | undefined { const { data } = useSWR( - () => (armToken && subscriptionId ? ["databaseAccounts", subscriptionId, armToken] : undefined), - (_, subscriptionId, armToken) => fetchDatabaseAccountsFromGraph(subscriptionId, armToken), + () => (subscriptionId ? ["databaseAccounts", subscriptionId] : undefined), + (_, subscriptionId) => runCommand(fetchDatabaseAccountsFromGraph, subscriptionId), ); return data; } + +// Define the types for your responses +interface DatabaseAccount { + name: string; + id: string; + // Add other relevant fields as per your use case +} + +interface QueryRequestOptions { + $top?: number; + $skipToken?: string; + $allowPartialScopes?: boolean; +} + +// Define the configuration context and headers if not already defined +const configContext = { + ARM_ENDPOINT: "https://management.azure.com/", + AAD_ENDPOINT: "https://login.microsoftonline.com/", +}; + +interface QueryResponse { + data?: any[]; + $skipToken?: string; +} + +export async function runCommand(fn: (...args: any[]) => Promise, ...args: any[]): Promise { + try { + // Attempt to execute the function passed as an argument + const result = await fn(...args); + console.log("Successfully executed function:", fn.name, result); + return result; + } catch (error) { + // Handle any error that is thrown during the execution of the function + if (error) { + console.log("Creating new token"); + const msalInstance = await getMsalInstance(); + + const cachedAccount = msalInstance.getAllAccounts()?.[0]; + const cachedTenantId = localStorage.getItem("cachedTenantId"); + + msalInstance.setActiveAccount(cachedAccount); + + // TODO: Add condition to check if the ARM token needs to be renewed, then we need to run the code below for creating the ARM token + + console.log("Creating new ARM token"); + const newAccessToken = await acquireTokenWithMsal(msalInstance, { + authority: `${configContext.AAD_ENDPOINT}${cachedTenantId}`, + scopes: [`${configContext.ARM_ENDPOINT}/.default`], + }); + updateUserContext({ armToken: newAccessToken }); + + // TODO: add condition to check if AAD token needs to be renewed (i.e) Token provider has failed with expired AAD token and create a new AAD Token using the below code + + // const hrefEndpoint = new URL(userContext.databaseAccount.properties.documentEndpoint).href.replace(/\/$/, "/.default"); + // console.log('Creating new AAD token'); + // let aadToken = await acquireTokenWithMsal(msalInstance, { + // forceRefresh: true, + // authority: `${configContext.AAD_ENDPOINT}${cachedTenantId}`, + // scopes: [hrefEndpoint], + // }); + // updateUserContext({aadToken: aadToken}); + + //console.log('Latest AAD Token', fn.name, userContext.aadToken); + const result = await fn(...args); + return result; + } else { + console.error("An error occurred:", error.message); + throw new error(); + } + } +} diff --git a/src/hooks/useSubscriptions.tsx b/src/hooks/useSubscriptions.tsx index ca80a87f5..8181c70d6 100644 --- a/src/hooks/useSubscriptions.tsx +++ b/src/hooks/useSubscriptions.tsx @@ -3,6 +3,8 @@ import { QueryRequestOptions, QueryResponse } from "Contracts/AzureResourceGraph import useSWR from "swr"; import { configContext } from "../ConfigContext"; import { Subscription } from "../Contracts/DataModels"; +import { runCommand } from "hooks/useDatabaseAccounts"; +import { userContext } from "UserContext"; /* eslint-disable @typescript-eslint/no-explicit-any */ interface SubscriptionListResult { @@ -35,9 +37,9 @@ export async function fetchSubscriptions(accessToken: string): Promise a.displayName.localeCompare(b.displayName)); } -export async function fetchSubscriptionsFromGraph(accessToken: string): Promise { +export async function fetchSubscriptionsFromGraph(): Promise { const headers = new Headers(); - const bearer = `Bearer ${accessToken}`; + const bearer = `Bearer ${userContext.armToken}`; headers.append("Authorization", bearer); headers.append(HttpHeaders.contentType, "application/json"); @@ -48,6 +50,7 @@ export async function fetchSubscriptionsFromGraph(accessToken: string): Promise< const subscriptions: Subscription[] = []; let skipToken: string; + console.log("Old ARM Token fetchSubscriptionsFromGraph fn", userContext.armToken); do { const body = { query: subscriptionsQuery, @@ -86,9 +89,14 @@ export async function fetchSubscriptionsFromGraph(accessToken: string): Promise< } export function useSubscriptions(armToken: string): Subscription[] | undefined { - const { data } = useSWR( + const { data, error } = useSWR( () => (armToken ? ["subscriptions", armToken] : undefined), - (_, armToken) => fetchSubscriptionsFromGraph(armToken), + (_) => runCommand(fetchSubscriptionsFromGraph), ); + + if (error) { + console.error("Error fetching subscriptions:", error); + } + return data; }