diff --git a/lib/api/api-request.js b/lib/api/api-request.js index e5d598aa6dd..073d7096595 100644 --- a/lib/api/api-request.js +++ b/lib/api/api-request.js @@ -6,7 +6,6 @@ const { InvalidArgumentError } = require('../core/errors') const util = require('../core/util') const { getResolveErrorBodyCallback } = require('./util') const { AsyncResource } = require('node:async_hooks') -const { addSignal, removeSignal } = require('./abort-signal') class RequestHandler extends AsyncResource { constructor (opts, callback) { @@ -56,19 +55,18 @@ class RequestHandler extends AsyncResource { this.onInfo = onInfo || null this.throwOnError = throwOnError this.highWaterMark = highWaterMark + this.signal = signal if (util.isStream(body)) { body.on('error', (err) => { this.onError(err) }) } - - addSignal(this, signal) } onConnect (abort, context) { - if (this.reason) { - abort(this.reason) + if (this.signal.aborted) { + abort(this.signal.reason) return } @@ -76,6 +74,16 @@ class RequestHandler extends AsyncResource { this.abort = abort this.context = context + + if (this.signal) { + this.removeAbortListener = util.addAbortListener(this.signal, () => { + if (this.res) { + this.res.destroy(this.signal.reason) + } else { + this.abort(this.signal.reason) + } + }) + } } onHeaders (statusCode, rawHeaders, resume, statusMessage) { @@ -95,6 +103,10 @@ class RequestHandler extends AsyncResource { const contentLength = parsedHeaders['content-length'] const body = new Readable({ resume, abort, contentType, contentLength, highWaterMark }) + if (this.removeAbortListener) { + body.on('close', this.removeAbortListener) + } + this.callback = null this.res = body if (callback !== null) { @@ -123,8 +135,6 @@ class RequestHandler extends AsyncResource { onComplete (trailers) { const { res } = this - removeSignal(this) - util.parseHeaders(trailers, this.trailers) res.push(null) @@ -133,8 +143,6 @@ class RequestHandler extends AsyncResource { onError (err) { const { res, callback, body, opaque } = this - removeSignal(this) - if (callback) { // TODO: Does this need queueMicrotask? this.callback = null @@ -149,6 +157,8 @@ class RequestHandler extends AsyncResource { queueMicrotask(() => { util.destroy(res, err) }) + } else if (this.removeAbortListener) { + this.removeAbortListener() } if (body) { diff --git a/test/request-signal.js b/test/request-signal.js new file mode 100644 index 00000000000..fd4d2f885a5 --- /dev/null +++ b/test/request-signal.js @@ -0,0 +1,76 @@ +'use strict' + +const { createServer } = require('node:http') +const { test, after } = require('node:test') +const { tspl } = require('@matteo.collina/tspl') +const { request } = require('..') + +test('pre abort signal w/ reason', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer((req, res) => { + res.end('asd') + }) + after(() => server.close()) + + server.listen(0, async () => { + const ac = new AbortController() + const _err = new Error() + ac.abort(_err) + try { + await request(`http://0.0.0.0:${server.address().port}`, { signal: ac.signal }) + } catch (err) { + t.equal(err, _err) + } + }) + await t.completed +}) + +test('post abort signal', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer((req, res) => { + res.end('asd') + }) + after(() => server.close()) + + server.listen(0, async () => { + const ac = new AbortController() + const ures = await request(`http://0.0.0.0:${server.address().port}`, { signal: ac.signal }) + ac.abort() + try { + /* eslint-disable-next-line no-unused-vars */ + for await (const chunk of ures.body) { + // Do nothing... + } + } catch (err) { + t.equal(err.name, 'AbortError') + } + }) + await t.completed +}) + +test('post abort signal w/ reason', async (t) => { + t = tspl(t, { plan: 1 }) + + const server = createServer((req, res) => { + res.end('asd') + }) + after(() => server.close()) + + server.listen(0, async () => { + const ac = new AbortController() + const _err = new Error() + const ures = await request(`http://0.0.0.0:${server.address().port}`, { signal: ac.signal }) + ac.abort(_err) + try { + /* eslint-disable-next-line no-unused-vars */ + for await (const chunk of ures.body) { + // Do nothing... + } + } catch (err) { + t.equal(err, _err) + } + }) + await t.completed +})