豆豆友情提示:这是一个非官方 GitHub 代理镜像,主要用于网络测试或访问加速。请勿在此进行登录、注册或处理任何敏感信息。进行这些操作请务必访问官方网站 github.com。 Raw 内容也通过此代理提供。
Skip to content

Commit e31d0bb

Browse files
MarkDaoust and copybara-github
authored and committed
chore: Optimize streaming
This CL improves the efficiency of the streaming LineDecoder and SSE message iterator, specifically addressing performance degradation when processing large or highly fragmented payloads. Key changes:

* Optimized Buffer Scanning: Introduced searchIndex in LineDecoder to track the last searched position. This prevents the decoder from rescanning the entire accumulated buffer from the beginning for every new chunk, reducing complexity from $O(n^2)$ to $O(n)$.
* Faster Pattern Matching: Replaced manual byte-by-byte loops with Uint8Array.prototype.indexOf in findNewlineIndex and findDoubleNewlineIndex for better performance.
* Simplified Streaming Logic: Replaced iterSSEChunks with iterBinaryChunks in streaming.ts. This removes redundant buffering logic, as the stateful LineDecoder already handles fragmented line assembly.
* Boundary Safety: Implemented a 1-byte search overlap (length - 1) to correctly detect multi-byte newline sequences (like \r\n) when they are split across chunk boundaries.
* Regression Testing: Added unit tests in api_client_test.ts and interactions_test.ts to verify that large, fragmented payloads are handled correctly.

PiperOrigin-RevId: 878636096
1 parent a2ada3f commit e31d0bb

File tree

5 files changed

+196
-42
lines changed

5 files changed

+196
-42
lines changed

src/interactions/core/streaming.ts

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import { GeminiNextGenAPIClientError } from './error.js';
88
import { type ReadableStream } from '../internal/shim-types.js';
99
import { makeReadableStream } from '../internal/shims.js';
10-
import { findDoubleNewlineIndex, LineDecoder } from '../internal/decoders/line.js';
10+
import { LineDecoder } from '../internal/decoders/line.js';
1111
import { ReadableStreamToAsyncIterable } from '../internal/shims.js';
1212
import { isAbortError } from '../internal/errors.js';
1313
import { encodeUTF8 } from '../internal/utils/bytes.js';
@@ -222,7 +222,7 @@ export async function* _iterSSEMessages(
222222
const lineDecoder = new LineDecoder();
223223

224224
const iter = ReadableStreamToAsyncIterable<Bytes>(response.body);
225-
for await (const sseChunk of iterSSEChunks(iter)) {
225+
for await (const sseChunk of iterBinaryChunks(iter)) {
226226
for (const line of lineDecoder.decode(sseChunk)) {
227227
const sse = sseDecoder.decode(line);
228228
if (sse) yield sse;
@@ -236,12 +236,10 @@ export async function* _iterSSEMessages(
236236
}
237237

238238
/**
239-
* Given an async iterable iterator, iterates over it and yields full
240-
* SSE chunks, i.e. yields when a double new-line is encountered.
239+
* Given an async iterable iterator, normalizes each chunk to a
240+
* Uint8Array and yields it.
241241
*/
242-
async function* iterSSEChunks(iterator: AsyncIterableIterator<Bytes>): AsyncGenerator<Uint8Array> {
243-
let data = new Uint8Array();
244-
242+
async function* iterBinaryChunks(iterator: AsyncIterableIterator<Bytes>): AsyncGenerator<Uint8Array> {
245243
for await (const chunk of iterator) {
246244
if (chunk == null) {
247245
continue;
@@ -252,20 +250,7 @@ async function* iterSSEChunks(iterator: AsyncIterableIterator<Bytes>): AsyncGene
252250
: typeof chunk === 'string' ? encodeUTF8(chunk)
253251
: chunk;
254252

255-
let newData = new Uint8Array(data.length + binaryChunk.length);
256-
newData.set(data);
257-
newData.set(binaryChunk, data.length);
258-
data = newData;
259-
260-
let patternIndex;
261-
while ((patternIndex = findDoubleNewlineIndex(data)) !== -1) {
262-
yield data.slice(0, patternIndex);
263-
data = data.slice(patternIndex);
264-
}
265-
}
266-
267-
if (data.length > 0) {
268-
yield data;
253+
yield binaryChunk;
269254
}
270255
}
271256

src/interactions/internal/decoders/line.ts

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ export class LineDecoder {
2121

2222
private buffer: Uint8Array;
2323
private carriageReturnIndex: number | null;
24+
private searchIndex: number;
2425

2526
constructor() {
2627
this.buffer = new Uint8Array();
2728
this.carriageReturnIndex = null;
29+
this.searchIndex = 0;
2830
}
2931

3032
decode(chunk: Bytes): string[] {
@@ -41,7 +43,9 @@ export class LineDecoder {
4143

4244
const lines: string[] = [];
4345
let patternIndex;
44-
while ((patternIndex = findNewlineIndex(this.buffer, this.carriageReturnIndex)) != null) {
46+
while (
47+
(patternIndex = findNewlineIndex(this.buffer, this.carriageReturnIndex ?? this.searchIndex)) != null
48+
) {
4549
if (patternIndex.carriage && this.carriageReturnIndex == null) {
4650
// skip until we either get a corresponding `\n`, a new `\r` or nothing
4751
this.carriageReturnIndex = patternIndex.index;
@@ -56,6 +60,7 @@ export class LineDecoder {
5660
lines.push(decodeUTF8(this.buffer.subarray(0, this.carriageReturnIndex - 1)));
5761
this.buffer = this.buffer.subarray(this.carriageReturnIndex);
5862
this.carriageReturnIndex = null;
63+
this.searchIndex = 0;
5964
continue;
6065
}
6166

@@ -67,8 +72,11 @@ export class LineDecoder {
6772

6873
this.buffer = this.buffer.subarray(patternIndex.index);
6974
this.carriageReturnIndex = null;
75+
this.searchIndex = 0;
7076
}
7177

78+
this.searchIndex = Math.max(0, this.buffer.length - 1);
79+
7280
return lines;
7381
}
7482

@@ -96,45 +104,74 @@ function findNewlineIndex(
96104
const newline = 0x0a; // \n
97105
const carriage = 0x0d; // \r
98106

99-
for (let i = startIndex ?? 0; i < buffer.length; i++) {
100-
if (buffer[i] === newline) {
101-
return { preceding: i, index: i + 1, carriage: false };
102-
}
107+
const start = startIndex ?? 0;
108+
const nextNewline = buffer.indexOf(newline, start);
109+
const nextCarriage = buffer.indexOf(carriage, start);
103110

104-
if (buffer[i] === carriage) {
105-
return { preceding: i, index: i + 1, carriage: true };
106-
}
111+
if (nextNewline === -1 && nextCarriage === -1) {
112+
return null;
113+
}
114+
115+
let i: number;
116+
if (nextNewline !== -1 && nextCarriage !== -1) {
117+
i = Math.min(nextNewline, nextCarriage);
118+
} else {
119+
i = nextNewline !== -1 ? nextNewline : nextCarriage;
120+
}
121+
122+
if (buffer[i] === newline) {
123+
return { preceding: i, index: i + 1, carriage: false };
107124
}
108125

109-
return null;
126+
return { preceding: i, index: i + 1, carriage: true };
110127
}
111128

112-
export function findDoubleNewlineIndex(buffer: Uint8Array): number {
129+
export function findDoubleNewlineIndex(buffer: Uint8Array, startIndex: number = 0): number {
113130
// This function searches the buffer for the end patterns (\r\r, \n\n, \r\n\r\n)
114131
// and returns the index right after the first occurrence of any pattern,
115132
// or -1 if none of the patterns are found.
116133
const newline = 0x0a; // \n
117134
const carriage = 0x0d; // \r
118135

119-
for (let i = 0; i < buffer.length - 1; i++) {
120-
if (buffer[i] === newline && buffer[i + 1] === newline) {
136+
let i = startIndex;
137+
while (i < buffer.length - 1) {
138+
const nextNewline = buffer.indexOf(newline, i);
139+
const nextCarriage = buffer.indexOf(carriage, i);
140+
141+
if (nextNewline === -1 && nextCarriage === -1) {
142+
return -1;
143+
}
144+
145+
let index: number;
146+
if (nextNewline !== -1 && nextCarriage !== -1) {
147+
index = Math.min(nextNewline, nextCarriage);
148+
} else {
149+
index = nextNewline !== -1 ? nextNewline : nextCarriage;
150+
}
151+
152+
if (index >= buffer.length - 1) {
153+
return -1;
154+
}
155+
156+
if (buffer[index] === newline && buffer[index + 1] === newline) {
121157
// \n\n
122-
return i + 2;
158+
return index + 2;
123159
}
124-
if (buffer[i] === carriage && buffer[i + 1] === carriage) {
160+
if (buffer[index] === carriage && buffer[index + 1] === carriage) {
125161
// \r\r
126-
return i + 2;
162+
return index + 2;
127163
}
128164
if (
129-
buffer[i] === carriage &&
130-
buffer[i + 1] === newline &&
131-
i + 3 < buffer.length &&
132-
buffer[i + 2] === carriage &&
133-
buffer[i + 3] === newline
165+
buffer[index] === carriage &&
166+
buffer[index + 1] === newline &&
167+
index + 3 < buffer.length &&
168+
buffer[index + 2] === carriage &&
169+
buffer[index + 3] === newline
134170
) {
135171
// \r\n\r\n
136-
return i + 4;
172+
return index + 4;
137173
}
174+
i = index + 1;
138175
}
139176

140177
return -1;

test/unit/api_client_test.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,38 @@ describe('processStreamResponse', () => {
381381
const result = await resultHttpResponse.value.json();
382382
expect(result).toEqual(expectedResponse);
383383
});
384+
385+
it('should handle large fragmented payloads correctly (regression for $O(n^2)$)', async () => {
386+
const largeData = 'A'.repeat(10 * 1024);
387+
const jsonPayload = JSON.stringify({data: largeData});
388+
const ssePayload = `data: ${jsonPayload}\n\n`;
389+
390+
const stream = new Readable();
391+
const chunkSize = 1024;
392+
for (let i = 0; i < ssePayload.length; i += chunkSize) {
393+
stream.push(ssePayload.substring(i, i + chunkSize));
394+
}
395+
stream.push(null);
396+
397+
const readableStream = new ReadableStream({
398+
start(controller) {
399+
stream.on('data', (chunk) =>
400+
controller.enqueue(new TextEncoder().encode(chunk)),
401+
);
402+
stream.on('end', () => controller.close());
403+
},
404+
});
405+
const response = new Response(readableStream);
406+
407+
const generator = apiClient.processStreamResponse(response);
408+
const result = await generator.next();
409+
expect(result.done).toBeFalse();
410+
const json = await result.value.json();
411+
expect(json.data).toBe(largeData);
412+
413+
const final = await generator.next();
414+
expect(final.done).toBeTrue();
415+
});
384416
});
385417

386418
describe('ApiClient', () => {

test/unit/interactions_test.ts

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,4 +196,56 @@ describe('Interactions resource', () => {
196196
expect(headers.get('x-goog-api-key')).toBe('some-manual-key');
197197
});
198198
});
199+
200+
describe('streaming regression', () => {
201+
let client: GeminiNextGenAPIClient;
202+
let fetchSpy: jasmine.Spy<Fetch>;
203+
204+
beforeEach(() => {
205+
client = new GeminiNextGenAPIClient({
206+
clientAdapter,
207+
apiKey: 'test-api-key',
208+
});
209+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
210+
fetchSpy = spyOn(client as any, 'fetch');
211+
});
212+
213+
it('should handle large fragmented SSE payloads correctly', async () => {
214+
const largeData = 'A'.repeat(10 * 1024);
215+
const mockSSE =
216+
`data: {"event_type": "content.delta", "delta": {"type": "text", "text": "${
217+
largeData
218+
}"}}\n\n` + `data: [DONE]\n\n`;
219+
const sseBytes = new TextEncoder().encode(mockSSE);
220+
221+
const readableStream = new ReadableStream({
222+
start(controller) {
223+
const chunkSize = 1024;
224+
for (let i = 0; i < sseBytes.length; i += chunkSize) {
225+
controller.enqueue(sseBytes.subarray(i, i + chunkSize));
226+
}
227+
controller.close();
228+
},
229+
});
230+
231+
fetchSpy.and.resolveTo(new Response(readableStream));
232+
233+
const stream = await client.interactions.create({
234+
model: 'gemini-pro',
235+
input: 'test',
236+
stream: true,
237+
});
238+
239+
let received = '';
240+
for await (const event of stream) {
241+
if (
242+
event.event_type === 'content.delta' &&
243+
event.delta?.type === 'text'
244+
) {
245+
received += event.delta.text;
246+
}
247+
}
248+
expect(received).toBe(largeData);
249+
});
250+
});
199251
});

test/unit/models_test.ts

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,3 +1402,51 @@ describe('generateContentStream', () => {
14021402
}
14031403
});
14041404
});
1405+
1406+
describe('generateContentStream regression', () => {
1407+
it('should handle large fragmented SSE payloads correctly', async () => {
1408+
const client = new GoogleGenAI({vertexai: false, apiKey: 'fake-api-key'});
1409+
1410+
const largeText = 'A'.repeat(10 * 1024);
1411+
const mockResponse = {
1412+
candidates: [
1413+
{
1414+
content: {
1415+
parts: [{text: largeText}],
1416+
role: 'model',
1417+
},
1418+
},
1419+
],
1420+
};
1421+
1422+
const sseData = `data: ${JSON.stringify(mockResponse)}\n\n`;
1423+
const encoder = new TextEncoder();
1424+
const sseBytes = encoder.encode(sseData);
1425+
1426+
const readableStream = new ReadableStream({
1427+
start(controller) {
1428+
const chunkSize = 1024;
1429+
for (let i = 0; i < sseBytes.length; i += chunkSize) {
1430+
controller.enqueue(sseBytes.subarray(i, i + chunkSize));
1431+
}
1432+
controller.close();
1433+
},
1434+
});
1435+
1436+
spyOn(global, 'fetch').and.resolveTo(
1437+
new Response(readableStream, fetchOkOptions),
1438+
);
1439+
1440+
const stream = await client.models.generateContentStream({
1441+
model: 'gemini-pro',
1442+
contents: 'test',
1443+
});
1444+
1445+
let receivedText = '';
1446+
for await (const chunk of stream) {
1447+
receivedText += chunk.candidates?.[0]?.content?.parts?.[0]?.text ?? '';
1448+
}
1449+
1450+
expect(receivedText).toBe(largeText);
1451+
});
1452+
});

0 commit comments

Comments (0)