Skip to content

Commit

Permalink
Use an EventBufferer to ensure only one event across PCAs (#228400)
Browse files Browse the repository at this point in the history
  • Loading branch information
TylerLeonhardt committed Sep 12, 2024
1 parent 9148684 commit db2a1df
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 12 deletions.
107 changes: 107 additions & 0 deletions extensions/microsoft-authentication/src/common/event.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { Event } from 'vscode';

/**
* The EventBufferer is useful in situations in which you want
* to delay firing your events during some code.
* You can wrap that code and be sure that the event will not
* be fired during that wrap.
*
* ```
* const emitter: Emitter;
* const delayer = new EventDelayer();
* const delayedEvent = delayer.wrapEvent(emitter.event);
*
* delayedEvent(console.log);
*
* delayer.bufferEvents(() => {
* emitter.fire(); // event will not be fired yet
* });
*
* // event will only be fired at this point
* ```
*/
export class EventBufferer {

private data: { buffers: Function[] }[] = [];

wrapEvent<T>(event: Event<T>): Event<T>;
wrapEvent<T>(event: Event<T>, reduce: (last: T | undefined, event: T) => T): Event<T>;
wrapEvent<T, O>(event: Event<T>, reduce: (last: O | undefined, event: T) => O, initial: O): Event<O>;
wrapEvent<T, O>(event: Event<T>, reduce?: (last: T | O | undefined, event: T) => T | O, initial?: O): Event<O | T> {
return (listener, thisArgs?, disposables?) => {
return event(i => {
const data = this.data[this.data.length - 1];

// Non-reduce scenario
if (!reduce) {
// Buffering case
if (data) {
data.buffers.push(() => listener.call(thisArgs, i));
} else {
// Not buffering case
listener.call(thisArgs, i);
}
return;
}

// Reduce scenario
const reduceData = data as typeof data & {
/**
* The accumulated items that will be reduced.
*/
items?: T[];
/**
* The reduced result cached to be shared with other listeners.
*/
reducedResult?: T | O;
};

// Not buffering case
if (!reduceData) {
// TODO: Is there a way to cache this reduce call for all listeners?
listener.call(thisArgs, reduce(initial, i));
return;
}

// Buffering case
reduceData.items ??= [];
reduceData.items.push(i);
if (reduceData.buffers.length === 0) {
// Include a single buffered function that will reduce all events when we're done buffering events
data.buffers.push(() => {
// cache the reduced result so that the value can be shared across all listeners
reduceData.reducedResult ??= initial
? reduceData.items!.reduce(reduce as (last: O | undefined, event: T) => O, initial)
: reduceData.items!.reduce(reduce as (last: T | undefined, event: T) => T);
listener.call(thisArgs, reduceData.reducedResult);
});
}
}, undefined, disposables);
};
}

bufferEvents<R = void>(fn: () => R): R {
const data = { buffers: new Array<Function>() };
this.data.push(data);
const r = fn();
this.data.pop();
data.buffers.forEach(flush => flush());
return r;
}

async bufferEventsAsync<R = void>(fn: () => Promise<R>): Promise<R> {
const data = { buffers: new Array<Function>() };
this.data.push(data);
try {
const r = await fn();
return r;
} finally {
this.data.pop();
data.buffers.forEach(flush => flush());
}
}
}
51 changes: 39 additions & 12 deletions extensions/microsoft-authentication/src/node/authProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { ICachedPublicClientApplication } from '../common/publicClientCache';
import { MicrosoftAccountType, MicrosoftAuthenticationTelemetryReporter } from '../common/telemetryReporter';
import { loopbackTemplate } from './loopbackTemplate';
import { ScopeData } from '../common/scopeData';
import { EventBufferer } from '../common/event';

const redirectUri = 'https://vscode.dev/redirect';
const MSA_TID = '9188040d-6c67-4c5b-b112-36a304b66dad';
Expand All @@ -21,6 +22,7 @@ export class MsalAuthProvider implements AuthenticationProvider {

private readonly _disposables: { dispose(): void }[];
private readonly _publicClientManager: CachedPublicClientApplicationManager;
private readonly _eventBufferer = new EventBufferer();

/**
* Event to signal a change in authentication sessions for this provider.
Expand Down Expand Up @@ -49,15 +51,37 @@ export class MsalAuthProvider implements AuthenticationProvider {
) {
this._disposables = context.subscriptions;
this._publicClientManager = new CachedPublicClientApplicationManager(context.globalState, context.secrets, this._logger);
const accountChangeEvent = this._eventBufferer.wrapEvent(
this._publicClientManager.onDidAccountsChange,
(last, newEvent) => {
if (!last) {
return newEvent;
}
const mergedEvent = {
added: [...(last.added ?? []), ...(newEvent.added ?? [])],
deleted: [...(last.deleted ?? []), ...(newEvent.deleted ?? [])],
changed: [...(last.changed ?? []), ...(newEvent.changed ?? [])]
};

const dedupedEvent = {
added: Array.from(new Map(mergedEvent.added.map(item => [item.username, item])).values()),
deleted: Array.from(new Map(mergedEvent.deleted.map(item => [item.username, item])).values()),
changed: Array.from(new Map(mergedEvent.changed.map(item => [item.username, item])).values())
};

return dedupedEvent;
},
{ added: new Array<AccountInfo>(), deleted: new Array<AccountInfo>(), changed: new Array<AccountInfo>() }
)(e => this._handleAccountChange(e));
this._disposables.push(
this._onDidChangeSessionsEmitter,
this._publicClientManager,
this._publicClientManager.onDidAccountsChange(e => this._handleAccountChange(e))
accountChangeEvent
);
}

async initialize(): Promise<void> {
await this._publicClientManager.initialize();
await this._eventBufferer.bufferEventsAsync(() => this._publicClientManager.initialize());

// Send telemetry for existing accounts
for (const cachedPca of this._publicClientManager.getAll()) {
Expand All @@ -77,6 +101,7 @@ export class MsalAuthProvider implements AuthenticationProvider {
* @param param0 Event that contains the added and removed accounts
*/
private _handleAccountChange({ added, changed, deleted }: { added: AccountInfo[]; changed: AccountInfo[]; deleted: AccountInfo[] }) {
this._logger.debug(`[_handleAccountChange] added: ${added.length}, changed: ${changed.length}, deleted: ${deleted.length}`);
this._onDidChangeSessionsEmitter.fire({
added: added.map(this.sessionFromAccountInfo),
changed: changed.map(this.sessionFromAccountInfo),
Expand Down Expand Up @@ -225,17 +250,19 @@ export class MsalAuthProvider implements AuthenticationProvider {
? cachedPca.accounts.filter(a => a.homeAccountId === accountFilter.id)
: cachedPca.accounts;
const sessions: AuthenticationSession[] = [];
for (const account of accounts) {
try {
const result = await cachedPca.acquireTokenSilent({ account, scopes: scopesToSend, redirectUri });
sessions.push(this.sessionFromAuthenticationResult(result, originalScopes));
} catch (e) {
// If we can't get a token silently, the account is probably in a bad state so we should skip it
// MSAL will log this already, so we don't need to log it again
continue;
return this._eventBufferer.bufferEventsAsync(async () => {
for (const account of accounts) {
try {
const result = await cachedPca.acquireTokenSilent({ account, scopes: scopesToSend, redirectUri });
sessions.push(this.sessionFromAuthenticationResult(result, originalScopes));
} catch (e) {
// If we can't get a token silently, the account is probably in a bad state so we should skip it
// MSAL will log this already, so we don't need to log it again
continue;
}
}
}
return sessions;
return sessions;
});
}

private sessionFromAuthenticationResult(result: AuthenticationResult, scopes: readonly string[]): AuthenticationSession & { idToken: string } {
Expand Down
14 changes: 14 additions & 0 deletions extensions/microsoft-authentication/src/node/publicClientCache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,24 @@ export class CachedPublicClientApplicationManager implements ICachedPublicClient
}

const results = await Promise.allSettled(promises);
let pcasChanged = false;
for (const result of results) {
if (result.status === 'rejected') {
this._logger.error('[initialize] Error getting PCA:', result.reason);
} else {
if (!result.value.accounts.length) {
pcasChanged = true;
const pcaKey = JSON.stringify({ clientId: result.value.clientId, authority: result.value.authority });
this._pcaDisposables.get(pcaKey)?.dispose();
this._pcaDisposables.delete(pcaKey);
this._pcas.delete(pcaKey);
this._logger.debug(`[initialize] [${result.value.clientId}] [${result.value.authority}] PCA disposed because it's empty.`);
}
}
}
if (pcasChanged) {
await this._storePublicClientApplications();
}
this._logger.debug('[initialize] PublicClientApplicationManager initialized');
}

Expand Down Expand Up @@ -106,6 +119,7 @@ export class CachedPublicClientApplicationManager implements ICachedPublicClient
// The PCA has no more accounts, so we can dispose it so we're not keeping it
// around forever.
disposable.dispose();
this._pcaDisposables.delete(pcasKey);
this._pcas.delete(pcasKey);
this._logger.debug(`[_doCreatePublicClientApplication] [${clientId}] [${authority}] PCA disposed. Firing off storing of PCAs...`);
void this._storePublicClientApplications();
Expand Down

0 comments on commit db2a1df

Please sign in to comment.