diff --git a/src/execution/__tests__/stream-test.ts b/src/execution/__tests__/stream-test.ts index aed5211ae1..ea1175acb1 100644 --- a/src/execution/__tests__/stream-test.ts +++ b/src/execution/__tests__/stream-test.ts @@ -450,6 +450,56 @@ describe('Execute: stream directive', () => { }, ]); }); + it('Can stream a field that returns a list of promises with nested promises', async () => { + const document = parse(` + query { + friendList @stream(initialCount: 2) { + name + id + } + } + `); + const result = await complete(document, { + friendList: () => + friends.map((f) => + Promise.resolve({ + name: Promise.resolve(f.name), + id: Promise.resolve(f.id), + }), + ), + }); + expectJSON(result).toDeepEqual([ + { + data: { + friendList: [ + { + name: 'Luke', + id: '1', + }, + { + name: 'Han', + id: '2', + }, + ], + }, + hasNext: true, + }, + { + incremental: [ + { + items: [ + { + name: 'Leia', + id: '3', + }, + ], + path: ['friendList', 2], + }, + ], + hasNext: false, + }, + ]); + }); it('Handles rejections in a field that returns a list of promises before initialCount is reached', async () => { const document = parse(` query { @@ -531,11 +581,6 @@ describe('Execute: stream directive', () => { }, ], }, - ], - hasNext: true, - }, - { - incremental: [ { items: [{ name: 'Leia', id: '3' }], path: ['friendList', 2], @@ -984,11 +1029,6 @@ describe('Execute: stream directive', () => { }, ], }, - ], - hasNext: true, - }, - { - incremental: [ { items: [{ nonNullName: 'Han' }], path: ['friendList', 2], diff --git a/src/execution/execute.ts b/src/execution/execute.ts index 052cb8da25..99e8ec6c9c 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -1798,6 +1798,78 @@ function executeDeferredFragment( asyncPayloadRecord.addData(promiseOrData); } +async function completedItemsFromPromisedItem( + exeContext: ExecutionContext, + itemType: GraphQLOutputType, + fieldNodes: ReadonlyArray, + info: GraphQLResolveInfo, + path: Path, + itemPath: Path, + item: Promise, + asyncPayloadRecord: AsyncPayloadRecord, +): Promise<[unknown] | null> { + try { + try { + const resolved = await item; + let completedItem = completeValue( + exeContext, + itemType, + fieldNodes, + info, + itemPath, + resolved, + asyncPayloadRecord, + ); + if (isPromise(completedItem)) { + completedItem = await completedItem; + } + return [completedItem]; + } catch (rawError) { + const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); + const handledError = handleFieldError( + error, + itemType, + asyncPayloadRecord.errors, + ); + filterSubsequentPayloads(exeContext, itemPath, asyncPayloadRecord); + return [handledError]; + } + } catch (error) { + asyncPayloadRecord.errors.push(error); + filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); + return null; + } +} + +async function completedItemsFromPromisedCompletedItem( + exeContext: ExecutionContext, + itemType: GraphQLOutputType, + fieldNodes: ReadonlyArray, + path: Path, + itemPath: Path, + completedItem: Promise, + asyncPayloadRecord: AsyncPayloadRecord, +): Promise<[unknown] | null> { + try { + try { + return [await completedItem]; + } catch (rawError) { + const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); + const handledError = handleFieldError( + error, + itemType, + asyncPayloadRecord.errors, + ); + filterSubsequentPayloads(exeContext, itemPath, asyncPayloadRecord); + return [handledError]; + } + } catch (error) { + asyncPayloadRecord.errors.push(error); + filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); + return null; + } +} + function executeStreamField( path: Path, itemPath: Path, @@ -1816,24 +1888,18 @@ function executeStreamField( exeContext, }); if (isPromise(item)) { - const completedItems = completePromisedValue( - exeContext, - itemType, - fieldNodes, - info, - itemPath, - item, - asyncPayloadRecord, - ).then( - (value) => [value], - (error) => { - asyncPayloadRecord.errors.push(error); - filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); - return null; - }, + asyncPayloadRecord.addItems( + completedItemsFromPromisedItem( + exeContext, + itemType, + fieldNodes, + info, + path, + itemPath, + item, + asyncPayloadRecord, + ), ); - - asyncPayloadRecord.addItems(completedItems); return asyncPayloadRecord; } @@ -1866,27 +1932,17 @@ function executeStreamField( } if (isPromise(completedItem)) { - const completedItems = completedItem - .then(undefined, (rawError) => { - const error = locatedError(rawError, fieldNodes, pathToArray(itemPath)); - const handledError = handleFieldError( - error, - itemType, - asyncPayloadRecord.errors, - ); - filterSubsequentPayloads(exeContext, itemPath, asyncPayloadRecord); - return handledError; - }) - .then( - (value) => [value], - (error) => { - asyncPayloadRecord.errors.push(error); - filterSubsequentPayloads(exeContext, path, asyncPayloadRecord); - return null; - }, - ); - - asyncPayloadRecord.addItems(completedItems); + asyncPayloadRecord.addItems( + completedItemsFromPromisedCompletedItem( + exeContext, + itemType, + fieldNodes, + path, + itemPath, + completedItem, + asyncPayloadRecord, + ), + ); return asyncPayloadRecord; }