diff --git a/src/Common/Constants.ts b/src/Common/Constants.ts index 46b890d29..355b9760b 100644 --- a/src/Common/Constants.ts +++ b/src/Common/Constants.ts @@ -177,6 +177,7 @@ export class HttpHeaders { public static activityId: string = "x-ms-activity-id"; public static apiType: string = "x-ms-cosmos-apitype"; public static authorization: string = "authorization"; + public static graphAuthorization: string = "graph-authorization"; public static collectionIndexTransformationProgress: string = "x-ms-documentdb-collection-index-transformation-progress"; public static continuation: string = "x-ms-continuation"; diff --git a/src/HostedExplorer.tsx b/src/HostedExplorer.tsx index 04abb1dd5..9593ad092 100644 --- a/src/HostedExplorer.tsx +++ b/src/HostedExplorer.tsx @@ -7,9 +7,6 @@ import "../less/hostedexplorer.less"; import { AuthType } from "./AuthType"; import { DatabaseAccount } from "./Contracts/DataModels"; import "./Explorer/Menus/NavBar/MeControlComponent.less"; -import { useAADAuth } from "./hooks/useAADAuth"; -import { useConfig } from "./hooks/useConfig"; -import { useTokenMetadata } from "./hooks/usePortalAccessToken"; import { HostedExplorerChildFrame } from "./HostedExplorerChildFrame"; import { AccountSwitcher } from "./Platform/Hosted/Components/AccountSwitcher"; import { ConnectExplorer } from "./Platform/Hosted/Components/ConnectExplorer"; @@ -20,6 +17,9 @@ import { SignInButton } from "./Platform/Hosted/Components/SignInButton"; import "./Platform/Hosted/ConnectScreen.less"; import { extractMasterKeyfromConnectionString } from "./Platform/Hosted/HostedUtils"; import "./Shared/appInsights"; +import { useAADAuth } from "./hooks/useAADAuth"; +import { useConfig } from "./hooks/useConfig"; +import { useTokenMetadata } from "./hooks/usePortalAccessToken"; initializeIcons(); @@ -51,6 +51,7 @@ const App: React.FunctionComponent = () => { authType: AuthType.AAD, databaseAccount, authorizationToken: armToken, + graphAuthorizationToken: graphToken }; } else if (authType === AuthType.EncryptedToken) { frameWindow.hostedConfig = { diff --git a/src/HostedExplorerChildFrame.ts b/src/HostedExplorerChildFrame.ts index 2cff6c862..8fdce2b3c 100644 --- a/src/HostedExplorerChildFrame.ts +++ b/src/HostedExplorerChildFrame.ts @@ -10,6 +10,7 @@ export interface AAD { authType: AuthType.AAD; databaseAccount: DatabaseAccount; authorizationToken: string; + graphAuthorizationToken: string; } export interface ConnectionString { diff --git a/src/UserContext.ts b/src/UserContext.ts index b265ae80a..b282646d1 100644 --- a/src/UserContext.ts +++ b/src/UserContext.ts @@ -79,6 +79,7 @@ interface UserContext { collectionCreationDefaults: CollectionCreationDefaults; sampleDataConnectionInfo?: ParsedResourceTokenConnectionString; readonly vcoreMongoConnectionParams?: VCoreMongoConnectionParams; + readonly accountRestrictedFromUser?: boolean; } export type ApiType = "SQL" | "Mongo" | "Gremlin" | "Tables" | "Cassandra" | "Postgres" | "VCoreMongo"; @@ -171,3 +172,4 @@ function apiType(account: DatabaseAccount | undefined): ApiType { } export { updateUserContext, userContext }; + diff --git a/src/Utils/AuthorizationUtils.ts b/src/Utils/AuthorizationUtils.ts index 7fe1709c0..91cfe8abf 100644 --- a/src/Utils/AuthorizationUtils.ts +++ b/src/Utils/AuthorizationUtils.ts @@ -60,3 +60,27 @@ export function getMsalInstance() { const msalInstance = new msal.PublicClientApplication(config); return msalInstance; } + +export async function isAccountRestrictedFromUser(accountName: string, graphToken: string): Promise { + const checkUserAccessUrl: string = "https://localhost:12901/api/guest/accountrestrictions/accountrestrictedfromuser"; + // const authorizationHeader = getAuthorizationHeader(); + try { + const response: Response = await fetch(checkUserAccessUrl, { + method: "POST", + body: JSON.stringify({ + accountName + }), + headers: { + // [authorizationHeader.header]: authorizationHeader.token, + [Constants.HttpHeaders.graphAuthorization]: graphToken, + [Constants.HttpHeaders.contentType]: "application/json", + } + }); + + const responseText: string = await response.text(); + return responseText.toLowerCase() === "true"; + } catch (e) { + console.log(e); + throw new Error(e); + } +} diff --git a/src/hooks/useKnockoutExplorer.ts b/src/hooks/useKnockoutExplorer.ts index f515eb2d3..4f9205cea 100644 --- a/src/hooks/useKnockoutExplorer.ts +++ b/src/hooks/useKnockoutExplorer.ts @@ -36,7 +36,7 @@ import { extractFeatures } from "../Platform/Hosted/extractFeatures"; import { CollectionCreation } from "../Shared/Constants"; import { DefaultExperienceUtility } from "../Shared/DefaultExperienceUtility"; import { Node, PortalEnv, updateUserContext, userContext } from "../UserContext"; -import { getAuthorizationHeader, getMsalInstance } from "../Utils/AuthorizationUtils"; +import { getAuthorizationHeader, getMsalInstance, isAccountRestrictedFromUser } from "../Utils/AuthorizationUtils"; import { isInvalidParentFrameOrigin, shouldProcessMessage } from "../Utils/MessageValidation"; import { listKeys } from "../Utils/arm/generatedClients/cosmos/databaseAccounts"; import { DatabaseAccountListKeysResult } from "../Utils/arm/generatedClients/cosmos/types"; @@ -227,9 +227,11 @@ async function configureHosted(): Promise { async function configureHostedWithAAD(config: AAD): Promise { // TODO: Refactor. updateUserContext needs to be called twice because listKeys below depends on userContext.authorizationToken + const accountRestrictedFromUser: boolean = await isAccountRestrictedFromUser(config.databaseAccount.name, config.graphAuthorizationToken); updateUserContext({ authType: AuthType.AAD, authorizationToken: `Bearer ${config.authorizationToken}`, + accountRestrictedFromUser }); const account = config.databaseAccount; const accountResourceId = account.id;