diff --git a/lib/internal/streams/iter/consumers.js b/lib/internal/streams/iter/consumers.js index 1162439bf88c3a..be289b1ef18cdf 100644 --- a/lib/internal/streams/iter/consumers.js +++ b/lib/internal/streams/iter/consumers.js @@ -48,6 +48,7 @@ const { const { concatBytes, + yieldAbortable, } = require('internal/streams/iter/utils'); const { @@ -121,7 +122,9 @@ async function collectAsync(source, signal, limit) { signal?.throwIfAborted(); // Normalize source via from() - accepts strings, ArrayBuffers, protocols, etc. - const normalized = from(source); + const abortableSource = signal && isAsyncIterable(source) ? + yieldAbortable(source, signal) : source; + const normalized = from(abortableSource); const chunks = []; // Fast path: no signal and no limit @@ -136,8 +139,9 @@ async function collectAsync(source, signal, limit) { // Slow path: with signal or limit checks let totalBytes = 0; + const iterable = signal ? yieldAbortable(normalized, signal) : normalized; - for await (const batch of normalized) { + for await (const batch of iterable) { signal?.throwIfAborted(); for (let i = 0; i < batch.length; i++) { const chunk = batch[i]; diff --git a/lib/internal/streams/iter/pull.js b/lib/internal/streams/iter/pull.js index f2ac4033dc051a..95ec2a084cdaf3 100644 --- a/lib/internal/streams/iter/pull.js +++ b/lib/internal/streams/iter/pull.js @@ -14,19 +14,12 @@ const { ArrayPrototypeSlice, PromisePrototypeThen, PromiseResolve, - PromiseWithResolvers, - SafePromisePrototypeFinally, - SafePromiseRace, SymbolAsyncIterator, SymbolIterator, TypedArrayPrototypeGetByteLength, Uint8Array, } = primordials; -const { - markPromiseAsHandled, -} = internalBinding('util'); - const { codes: { ERR_INVALID_ARG_TYPE, @@ -60,6 +53,7 @@ const { parsePullArgs, toUint8Array, wrapError, + yieldAbortable, } = require('internal/streams/iter/utils'); const { @@ -690,81 +684,6 @@ async function* applyValidatedStatefulAsyncTransform(source, transform, options) options.signal?.throwIfAborted(); } -function getOnAbort(reject, signal) { - return () => reject(signal.reason); -} - -/** - * Read one item from an async iterator, rejecting early if the signal aborts. - * @param {AsyncIterator} iterator - The iterator to read from. - * @param {AbortSignal|undefined} signal - Optional abort signal. - * @returns {Promise>|IteratorResult} - */ -function abortableNext(iterator, signal) { - if (signal === undefined) { - return iterator.next(); - } - - signal.throwIfAborted(); - - const next = iterator.next(); - const { promise, reject } = PromiseWithResolvers(); - const onAbort = getOnAbort(reject, signal); - signal.addEventListener('abort', onAbort, { __proto__: null, once: true }); - if (signal.aborted) { - onAbort(); - } - - return SafePromisePrototypeFinally(SafePromiseRace([next, promise]), () => { - signal.removeEventListener('abort', onAbort); - }); -} - -/** - * Wrap an async source so each pending read is abort-aware. - * @param {AsyncIterable} source - The source to read from. - * @param {AbortSignal|undefined} signal - Optional abort signal. - * @returns {AsyncIterable} - */ -function yieldAbortable(source, signal) { - if (signal === undefined) { - return source; - } - - return { - __proto__: null, - async *[SymbolAsyncIterator]() { - const iterator = source[SymbolAsyncIterator](); - let completed = false; - let aborted = false; - - try { - while (true) { - const { done, value } = await abortableNext(iterator, signal); - if (done) { - completed = true; - return; - } - signal.throwIfAborted(); - yield value; - } - } catch (error) { - aborted = signal.aborted; - throw error; - } finally { - if (!completed && typeof iterator.return === 'function') { - const result = iterator.return(); - if (aborted) { - markPromiseAsHandled(result); - } else { - await result; - } - } - } - }, - }; -} - /** * Create an async pipeline from source through transforms. * @yields {Uint8Array[]} diff --git a/lib/internal/streams/iter/utils.js b/lib/internal/streams/iter/utils.js index 7829afaade832f..ca3d2531d6f9e0 100644 --- a/lib/internal/streams/iter/utils.js +++ b/lib/internal/streams/iter/utils.js @@ -8,7 +8,11 @@ const { MathMin, NumberMAX_SAFE_INTEGER, PromiseResolve, + PromiseWithResolvers, + SafePromisePrototypeFinally, + SafePromiseRace, String, + SymbolAsyncIterator, TypedArrayPrototypeGetBuffer, TypedArrayPrototypeGetByteLength, TypedArrayPrototypeGetByteOffset, @@ -16,6 +20,10 @@ const { Uint8Array, } = primordials; +const { + markPromiseAsHandled, +} = internalBinding('util'); + const { TextEncoder } = require('internal/encoding'); const { codes: { @@ -69,6 +77,81 @@ function onSignalAbort(signal, handler) { } } +function getOnAbort(reject, signal) { + return () => reject(signal.reason); +} + +/** + * Read one item from an async iterator, rejecting early if the signal aborts. + * @param {AsyncIterator} iterator - The iterator to read from. + * @param {AbortSignal|undefined} signal - Optional abort signal. + * @returns {Promise>|IteratorResult} + */ +function abortableNext(iterator, signal) { + if (signal === undefined) { + return iterator.next(); + } + + signal.throwIfAborted(); + + const next = iterator.next(); + const { promise, reject } = PromiseWithResolvers(); + const onAbort = getOnAbort(reject, signal); + signal.addEventListener('abort', onAbort, { __proto__: null, once: true }); + if (signal.aborted) { + onAbort(); + } + + return SafePromisePrototypeFinally(SafePromiseRace([next, promise]), () => { + signal.removeEventListener('abort', onAbort); + }); +} + +/** + * Wrap an async source so each pending read is abort-aware. + * @param {AsyncIterable} source - The source to read from. + * @param {AbortSignal|undefined} signal - Optional abort signal. + * @returns {AsyncIterable} + */ +function yieldAbortable(source, signal) { + if (signal === undefined) { + return source; + } + + return { + __proto__: null, + async *[SymbolAsyncIterator]() { + const iterator = source[SymbolAsyncIterator](); + let completed = false; + let aborted = false; + + try { + while (true) { + const { done, value } = await abortableNext(iterator, signal); + if (done) { + completed = true; + return; + } + signal.throwIfAborted(); + yield value; + } + } catch (error) { + aborted = signal.aborted; + throw error; + } finally { + if (!completed && typeof iterator.return === 'function') { + const result = iterator.return(); + if (aborted) { + markPromiseAsHandled(result); + } else { + await result; + } + } + } + }, + }; +} + /** * Compute the minimum cursor across a set of consumers and count how many * consumers are at that cursor. @@ -301,4 +384,5 @@ module.exports = { toUint8Array, validateBackpressure, wrapError, + yieldAbortable, }; diff --git a/test/parallel/test-stream-iter-consumers-bytes.js b/test/parallel/test-stream-iter-consumers-bytes.js index e45ee991d587fd..53d0a3858f87ec 100644 --- a/test/parallel/test-stream-iter-consumers-bytes.js +++ b/test/parallel/test-stream-iter-consumers-bytes.js @@ -14,6 +14,7 @@ const { arrayBufferSync, array, arraySync, + toAsyncStreamable, } = require('stream/iter'); // ============================================================================= @@ -51,6 +52,55 @@ async function testBytesAsyncAbort() { ); } +async function testAsyncConsumersAbortPendingNext() { + const consumers = [ + ['bytes', bytes], + ['text', text], + ['arrayBuffer', arrayBuffer], + ['array', array], + ]; + + for (const [name, consumer] of consumers) { + const ac = new AbortController(); + const reason = new Error(`${name} boom`); + + async function* never() { + await new Promise(() => {}); + yield []; + } + + const promise = consumer(never(), { __proto__: null, signal: ac.signal }); + ac.abort(reason); + + await assert.rejects(promise, reason); + } +} + +async function testAsyncConsumersAbortPendingNormalization() { + const consumers = [ + ['bytes', bytes], + ['text', text], + ['arrayBuffer', arrayBuffer], + ['array', array], + ]; + + for (const [name, consumer] of consumers) { + const ac = new AbortController(); + const reason = new Error(`${name} normalization boom`); + const source = { + __proto__: null, + [toAsyncStreamable]() { + return new Promise(() => {}); + }, + }; + + const promise = consumer(source, { __proto__: null, signal: ac.signal }); + ac.abort(reason); + + await assert.rejects(promise, reason); + } +} + async function testBytesEmpty() { const data = await bytes(from([])); assert.ok(data instanceof Uint8Array); @@ -203,6 +253,8 @@ Promise.all([ testBytesAsync(), testBytesAsyncLimit(), testBytesAsyncAbort(), + testAsyncConsumersAbortPendingNext(), + testAsyncConsumersAbortPendingNormalization(), testBytesEmpty(), testArrayBufferSyncBasic(), testArrayBufferAsync(),