diff --git a/src/connection.ts b/src/connection.ts index 9f3875a3..3d4c23e6 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -30,7 +30,7 @@ import { PublishErrorListener, ResponseDecoder, } from "./response_decoder" -import { removeFrom } from "./util" +import { DEFAULT_MAX_FRAME_SIZE, DEFAULT_UNLIMITED_FRAME_SIZE, removeFrom } from "./util" import { WaitingResponse } from "./waiting_response" import { SubscribeResponse } from "./responses/subscribe_response" import { TuneResponse } from "./responses/tune_response" @@ -62,7 +62,7 @@ export class Connection { private consumers = new Map() private compressions = new Map() - constructor(private readonly logger: Logger) { + constructor(private readonly logger: Logger, private frameMax: number = DEFAULT_MAX_FRAME_SIZE) { this.heartbeat = new Heartbeat(this, this.logger) this.compressions.set(CompressionType.None, NoneCompression.create()) this.compressions.set(CompressionType.Gzip, GzipCompression.create()) @@ -89,7 +89,7 @@ export class Connection { } static connect(params: ConnectionParams, logger?: Logger): Promise { - return new Connection(logger ?? new NullLogger()).start(params) + return new Connection(logger ?? new NullLogger(), params.frameMax).start(params) } public start(params: ConnectionParams): Promise { @@ -235,6 +235,10 @@ export class Connection { return this.consumers.size } + public savedFrameMax() { + return this.frameMax + } + public send(cmd: Request): Promise { return new Promise((res, rej) => { const body = cmd.toBuffer() @@ -326,7 +330,8 @@ export class Connection { const heartbeat = extractHeartbeatInterval(heartbeatInterval, tuneResponse) return new Promise((res, rej) => { - const request = new TuneRequest({ frameMax: tuneResponse.frameMax, heartbeat }) + this.frameMax = this.calculateFrameMaxSizeFrom(tuneResponse.frameMax) + const request = new TuneRequest({ frameMax: this.frameMax, heartbeat }) this.socket.write(request.toBuffer(), (err) => { this.logger.debug(`Write COMPLETED for cmd TUNE: ${inspect(tuneResponse)} - err: ${err}`) return err ? rej(err) : res({ heartbeat }) @@ -463,6 +468,13 @@ export class Connection { response.messages.map((x) => consumer.handle(x)) }) } + + private calculateFrameMaxSizeFrom(tuneResponseFrameMax: number) { + if (this.frameMax === DEFAULT_UNLIMITED_FRAME_SIZE && tuneResponseFrameMax === DEFAULT_UNLIMITED_FRAME_SIZE) { + return 0 + } + return Math.max(1, Math.min(this.frameMax, tuneResponseFrameMax)) + } } export type ListenersParams = { diff --git a/src/util.ts b/src/util.ts index f1c98a98..d67296ad 100644 --- a/src/util.ts +++ b/src/util.ts @@ -31,3 +31,6 @@ export function range(count: number): number[] { } return ret } + +export const DEFAULT_MAX_FRAME_SIZE = 1048576 +export const DEFAULT_UNLIMITED_FRAME_SIZE = 0 diff --git a/test/e2e/connect_frame_size_negotiation.test.ts b/test/e2e/connect_frame_size_negotiation.test.ts new file mode 100644 index 00000000..a3a818fa --- /dev/null +++ b/test/e2e/connect_frame_size_negotiation.test.ts @@ -0,0 +1,38 @@ +import { expect } from "chai" +import { createConnection } from "../support/fake_data" +import { Rabbit } from "../support/rabbit" +import { eventually, username, password } from "../support/util" + +describe("connect frame size negotiation", () => { + const rabbit = new Rabbit(username, password) + + it("using 65536 as frameMax", async () => { + const frameMax = 65536 + + const connection = await createConnection(username, password, undefined, frameMax) + + await eventually(async () => { + expect(connection.savedFrameMax()).lte(frameMax) + expect(await rabbit.getConnections()).lengthOf(1) + }, 5000) + try { + await connection.close() + await rabbit.closeAllConnections() + } catch (e) {} + }).timeout(10000) + + it("using 1024 as frameMax", async () => { + const frameMax = 1024 + + const connection = await createConnection(username, password, undefined, frameMax) + + await eventually(async () => { + expect(connection.savedFrameMax()).lte(frameMax) + expect(await rabbit.getConnections()).lengthOf(1) + }, 5000) + try { + await connection.close() + await rabbit.closeAllConnections() + } catch (e) {} + }).timeout(10000) +}) diff --git a/test/support/fake_data.ts b/test/support/fake_data.ts index 5617856e..a21ff786 100644 --- a/test/support/fake_data.ts +++ b/test/support/fake_data.ts @@ -32,14 +32,14 @@ export async function createPublisher(streamName: string, connection: Connection return publisher } -export function createConnection(username: string, password: string, listeners?: ListenersParams) { +export function createConnection(username: string, password: string, listeners?: ListenersParams, frameMax?: number) { return connect({ hostname: "localhost", port: 5552, username, password, vhost: "/", - frameMax: 0, // not used + frameMax: frameMax ?? 0, // not used heartbeat: 0, listeners: listeners, })