From 1248aa99bdd39424b514614da91666f4c029c393 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filipe=20Caba=C3=A7o?= Date: Mon, 7 Apr 2025 23:07:37 +0100 Subject: [PATCH 1/4] fix: Add heartbeat callback; Move to Sets vs Arrays --- src/RealtimeClient.ts | 48 ++++++++++++++++++------ test/channel.test.ts | 13 ++++--- test/client.test.ts | 85 +++++++++++++++++++++++++++++++------------ 3 files changed, 104 insertions(+), 42 deletions(-) diff --git a/src/RealtimeClient.ts b/src/RealtimeClient.ts index a0f5854..9a75131 100755 --- a/src/RealtimeClient.ts +++ b/src/RealtimeClient.ts @@ -61,6 +61,7 @@ export type RealtimeClientOptions = { transport?: WebSocketLikeConstructor timeout?: number heartbeatIntervalMs?: number + heartbeatCallback?: Function logger?: Function encode?: Function decode?: Function @@ -86,7 +87,7 @@ const WORKER_SCRIPT = ` export default class RealtimeClient { accessTokenValue: string | null = null apiKey: string | null = null - channels: RealtimeChannel[] = [] + channels: Set = new Set() endPoint: string = '' httpEndpoint: string = '' headers?: { [key: string]: string } = DEFAULT_HEADERS @@ -96,6 +97,7 @@ export default class RealtimeClient { heartbeatIntervalMs: number = 25000 heartbeatTimer: ReturnType | undefined = undefined pendingHeartbeatRef: string | null = null + heartbeatCallback: Function = noop ref: number = 0 reconnectTimer: Timer logger: Function = noop @@ -133,6 +135,7 @@ export default class RealtimeClient { * @param options.params The optional params to pass when connecting. * @param options.headers The optional headers to pass when connecting. * @param options.heartbeatIntervalMs The millisec interval to send a heartbeat message. + * @param options.heartbeatCallback Function to be executed on heartbeat check. * @param options.logger The optional function for specialized logging, ie: logger: (kind, msg, data) => { console.log(`${kind}: ${msg}`, data) } * @param options.logLevel Sets the log level for Realtime * @param options.encode The function to encode outgoing messages. Defaults to JSON: (payload, callback) => callback(JSON.stringify(payload)) @@ -161,6 +164,9 @@ export default class RealtimeClient { if (options?.heartbeatIntervalMs) this.heartbeatIntervalMs = options.heartbeatIntervalMs + if (options?.heartbeatCallback) + this.heartbeatCallback = options.heartbeatCallback + const accessTokenValue = options?.params?.apikey if (accessTokenValue) { this.accessTokenValue = accessTokenValue @@ -268,7 +274,7 @@ export default class RealtimeClient { * Returns all created channels */ getChannels(): RealtimeChannel[] { - return this.channels + return Array.from(this.channels) } /** @@ -279,7 +285,7 @@ export default class RealtimeClient { channel: RealtimeChannel ): Promise { const status = await channel.unsubscribe() - if (this.channels.length === 0) { + if (this.channels.values.length === 0) { this.disconnect() } return status @@ -290,9 +296,13 @@ export default class RealtimeClient { */ async removeAllChannels(): Promise { const values_1 = await Promise.all( - this.channels.map((channel) => channel.unsubscribe()) + Array.from(this.channels).map((channel) => { + this.channels.delete(channel) + return channel.unsubscribe() + }) ) this.disconnect() + return values_1 } @@ -332,9 +342,18 @@ export default class RealtimeClient { topic: string, params: RealtimeChannelOptions = { config: {} } ): RealtimeChannel { - const chan = new RealtimeChannel(`realtime:${topic}`, params, this) - this.channels.push(chan) - return chan + const realtimeTopic = `realtime:${topic}` + const exists = this.getChannels().find( + (c: RealtimeChannel) => c.topic === realtimeTopic + ) + + if (!exists) { + const chan = new RealtimeChannel(`realtime:${topic}`, params, this) + this.channels.add(chan) + return chan + } else { + return exists + } } /** @@ -412,6 +431,7 @@ export default class RealtimeClient { payload: {}, ref: this.pendingHeartbeatRef, }) + await this.setAuth() } @@ -467,7 +487,7 @@ export default class RealtimeClient { * @internal */ _leaveOpenTopic(topic: string): void { - let dupChannel = this.channels.find( + let dupChannel = Array.from(this.channels).find( (c) => c.topic === topic && (c._isJoined() || c._isJoining()) ) if (dupChannel) { @@ -484,9 +504,7 @@ export default class RealtimeClient { * @internal */ _remove(channel: RealtimeChannel) { - this.channels = this.channels.filter( - (c: RealtimeChannel) => c._joinRef() !== channel._joinRef() - ) + this.channels.delete(channel) } /** @@ -510,6 +528,10 @@ export default class RealtimeClient { this.decode(rawMessage.data, (msg: RealtimeMessage) => { let { topic, event, payload, ref } = msg + if (topic === 'phoenix' && event === 'phx_reply') { + this.heartbeatCallback(msg) + } + if (ref && ref === this.pendingHeartbeatRef) { this.pendingHeartbeatRef = null } @@ -521,11 +543,13 @@ export default class RealtimeClient { }`, payload ) - this.channels + + Array.from(this.channels) .filter((channel: RealtimeChannel) => channel._isMember(topic)) .forEach((channel: RealtimeChannel) => channel._trigger(event, payload, ref) ) + this.stateChangeCallbacks.message.forEach((callback) => callback(msg)) }) } diff --git a/test/channel.test.ts b/test/channel.test.ts index f516b11..71a449c 100644 --- a/test/channel.test.ts +++ b/test/channel.test.ts @@ -9,6 +9,7 @@ import { beforeAll, afterAll, vi, + it, } from 'vitest' import RealtimeClient from '../src/RealtimeClient' @@ -814,12 +815,12 @@ describe('onClose', () => { }) test('removes channel from socket', () => { - assert.equal(socket.channels.length, 1) - assert.deepEqual(socket.channels[0], channel) + assert.equal(socket.getChannels().length, 1) + assert.deepEqual(socket.getChannels()[0], channel) channel._trigger('phx_close') - assert.equal(socket.channels.length, 0) + assert.equal(socket.getChannels().length, 0) }) }) @@ -1113,13 +1114,13 @@ describe('leave', () => { test("closes channel on 'ok' from server", () => { const anotherChannel = socket.channel('another', { three: 'four' }) - assert.equal(socket.channels.length, 2) + assert.equal(socket.getChannels().length, 2) channel.unsubscribe() channel.joinPush.trigger('ok', {}) - assert.equal(socket.channels.length, 1) - assert.deepEqual(socket.channels[0], anotherChannel) + assert.equal(socket.getChannels().length, 1) + assert.deepEqual(socket.getChannels()[0], anotherChannel) }) test("sets state to closed on 'ok' event", () => { diff --git a/test/client.test.ts b/test/client.test.ts index 29c90bb..bbeeaa9 100644 --- a/test/client.test.ts +++ b/test/client.test.ts @@ -43,7 +43,7 @@ describe('constructor', () => { test('sets defaults', () => { let socket = new RealtimeClient(url) - assert.equal(socket.channels.length, 0) + assert.equal(socket.getChannels().length, 0) assert.equal(socket.sendBuffer.length, 0) assert.equal(socket.ref, 0) assert.equal(socket.endPoint, `${url}/websocket`) @@ -218,6 +218,7 @@ describe('channel', () => { one: 'two', }) }) + test('returns channel with given topic and params for a private channel', () => { channel = socket.channel('topic', { config: { private: true }, one: 'two' }) @@ -232,17 +233,29 @@ describe('channel', () => { one: 'two', }) }) + test('adds channel to sockets channels list', () => { - assert.equal(socket.channels.length, 0) + assert.equal(socket.getChannels().length, 0) channel = socket.channel('topic') - assert.equal(socket.channels.length, 1) + assert.equal(socket.getChannels().length, 1) const [foundChannel] = socket.channels assert.deepStrictEqual(foundChannel, channel) }) + test('does not repeat channels to sockets channels list', () => { + assert.equal(socket.getChannels().length, 0) + + channel = socket.channel('topic') + socket.channel('topic') // should be ignored + + assert.equal(socket.getChannels().length, 1) + + const [foundChannel] = socket.channels + assert.deepStrictEqual(foundChannel, channel) + }) test('gets all channels', () => { assert.equal(socket.getChannels().length, 0) @@ -258,12 +271,12 @@ describe('channel', () => { channel = socket.channel('topic').subscribe() - assert.equal(socket.channels.length, 1) + assert.equal(socket.getChannels().length, 1) assert.ok(connectStub.called) await socket.removeChannel(channel) - assert.equal(socket.channels.length, 0) + assert.equal(socket.getChannels().length, 0) assert.ok(disconnectStub.called) }) @@ -273,11 +286,11 @@ describe('channel', () => { socket.channel('chan1').subscribe() socket.channel('chan2').subscribe() - assert.equal(socket.channels.length, 2) + assert.equal(socket.getChannels().length, 2) await socket.removeAllChannels() - assert.equal(socket.channels.length, 0) + assert.equal(socket.getChannels().length, 0) assert.ok(disconnectStub.called) }) }) @@ -295,10 +308,18 @@ describe('leaveOpenTopic', () => { channel1 = socket.channel('topic') channel2 = socket.channel('topic') channel1.subscribe() - channel2.subscribe() - - assert.equal(socket.channels.length, 1) - assert.equal(socket.channels[0].topic, 'realtime:topic') + try { + channel2.subscribe() + } catch (e) { + console.error(e) + assert.equal( + e, + `tried to subscribe multiple times. 'subscribe' can only be called a single time per channel instance` + ) + } + + assert.equal(socket.getChannels().length, 1) + assert.equal(socket.getChannels()[0].topic, 'realtime:topic') }) }) @@ -312,7 +333,7 @@ describe('remove', () => { socket._remove(channel1) - assert.equal(socket.channels.length, 1) + assert.equal(socket.getChannels().length, 1) const [foundChannel] = socket.channels assert.deepStrictEqual(foundChannel, channel2) @@ -379,9 +400,9 @@ describe('setAuth', () => { }) test("sets access token, updates channels' join payload, and pushes token to channels", async () => { - const channel1 = socket.channel('test-topic') - const channel2 = socket.channel('test-topic') - const channel3 = socket.channel('test-topic') + const channel1 = socket.channel('test-topic1') + const channel2 = socket.channel('test-topic2') + const channel3 = socket.channel('test-topic3') channel1.state = CHANNEL_STATES.joined channel2.state = CHANNEL_STATES.closed @@ -432,9 +453,9 @@ describe('setAuth', () => { }) test("sets access token, updates channels' join payload, and pushes token to channels if is not a jwt", async () => { - const channel1 = socket.channel('test-topic') - const channel2 = socket.channel('test-topic') - const channel3 = socket.channel('test-topic') + const channel1 = socket.channel('test-topic1') + const channel2 = socket.channel('test-topic2') + const channel3 = socket.channel('test-topic3') channel1.state = CHANNEL_STATES.joined channel2.state = CHANNEL_STATES.closed @@ -474,9 +495,9 @@ describe('setAuth', () => { accessToken: () => Promise.resolve(token), }) - const channel1 = new_socket.channel('test-topic') - const channel2 = new_socket.channel('test-topic') - const channel3 = new_socket.channel('test-topic') + const channel1 = new_socket.channel('test-topic1') + const channel2 = new_socket.channel('test-topic2') + const channel3 = new_socket.channel('test-topic3') channel1.state = CHANNEL_STATES.joined channel2.state = CHANNEL_STATES.closed @@ -508,9 +529,9 @@ describe('setAuth', () => { }) test("overrides access token, updates channels' join payload, and pushes token to channels", () => { - const channel1 = socket.channel('test-topic') - const channel2 = socket.channel('test-topic') - const channel3 = socket.channel('test-topic') + const channel1 = socket.channel('test-topic1') + const channel2 = socket.channel('test-topic2') + const channel3 = socket.channel('test-topic3') channel1.state = CHANNEL_STATES.joined channel2.state = CHANNEL_STATES.closed @@ -687,6 +708,22 @@ describe('onConnMessage', () => { assert.strictEqual(targetSpy.callCount, 1) assert.strictEqual(otherSpy.callCount, 0) }) + + test("on heartbeat events from the 'phoenix' topic, callback is called", async () => { + let called = false + let socket = new RealtimeClient(url, { + heartbeatCallback: () => (called = true), + }) + + const message = + '{"ref":"1","event":"phx_reply","payload":{"status":"ok","response":{}},"topic":"phoenix"}' + const data = { data: message } + + socket.pendingHeartbeatRef = '3' + socket._onConnMessage(data) + + assert.strictEqual(called, true) + }) }) describe('custom encoder and decoder', () => { From 1eb5a18421818872ac8e80a614daab59a65d3dfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filipe=20Caba=C3=A7o?= Date: Mon, 7 Apr 2025 23:11:57 +0100 Subject: [PATCH 2/4] Update src/RealtimeClient.ts Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/RealtimeClient.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/RealtimeClient.ts b/src/RealtimeClient.ts index 9a75131..a6f1edb 100755 --- a/src/RealtimeClient.ts +++ b/src/RealtimeClient.ts @@ -285,7 +285,7 @@ export default class RealtimeClient { channel: RealtimeChannel ): Promise { const status = await channel.unsubscribe() - if (this.channels.values.length === 0) { + if (this.channels.size === 0) { this.disconnect() } return status From b956921e6a0c4094743846b5e9fe391636708f07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filipe=20Caba=C3=A7o?= Date: Mon, 7 Apr 2025 23:20:16 +0100 Subject: [PATCH 3/4] change callback in test to use RealtimeMessage --- test/client.test.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/client.test.ts b/test/client.test.ts index bbeeaa9..45a764e 100644 --- a/test/client.test.ts +++ b/test/client.test.ts @@ -5,7 +5,7 @@ import WebSocket from 'ws' import sinon from 'sinon' import crypto from 'crypto' -import RealtimeClient from '../src/RealtimeClient' +import RealtimeClient, { RealtimeMessage } from '../src/RealtimeClient' import jwt from 'jsonwebtoken' import { CHANNEL_STATES, LOG_LEVEL } from '../src/lib/constants' @@ -712,7 +712,9 @@ describe('onConnMessage', () => { test("on heartbeat events from the 'phoenix' topic, callback is called", async () => { let called = false let socket = new RealtimeClient(url, { - heartbeatCallback: () => (called = true), + heartbeatCallback: (message: RealtimeMessage) => { + called = message.payload.status == 'ok' + }, }) const message = From 6fed6a684d337e4e43ebeb4f56b2e429ba07b392 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filipe=20Caba=C3=A7o?= Date: Tue, 15 Apr 2025 12:13:06 +0100 Subject: [PATCH 4/4] change to have multiple status and be a onHeartbeat function; expose a Heartbeat status --- src/RealtimeClient.ts | 22 ++++++++++++++-------- test/client.test.ts | 28 ++++++++++++++++++++++------ 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/RealtimeClient.ts b/src/RealtimeClient.ts index a6f1edb..827d470 100755 --- a/src/RealtimeClient.ts +++ b/src/RealtimeClient.ts @@ -38,6 +38,12 @@ export type RealtimeMessage = { } export type RealtimeRemoveChannelResponse = 'ok' | 'timed out' | 'error' +export type HeartbeatStatus = + | 'sent' + | 'ok' + | 'error' + | 'timeout' + | 'disconnected' const noop = () => {} @@ -61,7 +67,6 @@ export type RealtimeClientOptions = { transport?: WebSocketLikeConstructor timeout?: number heartbeatIntervalMs?: number - heartbeatCallback?: Function logger?: Function encode?: Function decode?: Function @@ -97,7 +102,7 @@ export default class RealtimeClient { heartbeatIntervalMs: number = 25000 heartbeatTimer: ReturnType | undefined = undefined pendingHeartbeatRef: string | null = null - heartbeatCallback: Function = noop + heartbeatCallback: (status: HeartbeatStatus) => void = noop ref: number = 0 reconnectTimer: Timer logger: Function = noop @@ -135,7 +140,6 @@ export default class RealtimeClient { * @param options.params The optional params to pass when connecting. * @param options.headers The optional headers to pass when connecting. * @param options.heartbeatIntervalMs The millisec interval to send a heartbeat message. - * @param options.heartbeatCallback Function to be executed on heartbeat check. * @param options.logger The optional function for specialized logging, ie: logger: (kind, msg, data) => { console.log(`${kind}: ${msg}`, data) } * @param options.logLevel Sets the log level for Realtime * @param options.encode The function to encode outgoing messages. Defaults to JSON: (payload, callback) => callback(JSON.stringify(payload)) @@ -164,9 +168,6 @@ export default class RealtimeClient { if (options?.heartbeatIntervalMs) this.heartbeatIntervalMs = options.heartbeatIntervalMs - if (options?.heartbeatCallback) - this.heartbeatCallback = options.heartbeatCallback - const accessTokenValue = options?.params?.apikey if (accessTokenValue) { this.accessTokenValue = accessTokenValue @@ -413,6 +414,7 @@ export default class RealtimeClient { */ async sendHeartbeat() { if (!this.isConnected()) { + this.heartbeatCallback('disconnected') return } if (this.pendingHeartbeatRef) { @@ -421,6 +423,7 @@ export default class RealtimeClient { 'transport', 'heartbeat timeout. Attempting to re-establish connection' ) + this.heartbeatCallback('timeout') this.conn?.close(WS_CLOSE_NORMAL, 'hearbeat timeout') return } @@ -431,10 +434,13 @@ export default class RealtimeClient { payload: {}, ref: this.pendingHeartbeatRef, }) - + this.heartbeatCallback('sent') await this.setAuth() } + onHeartbeat(callback: (status: HeartbeatStatus) => void): void { + this.heartbeatCallback = callback + } /** * Flushes send buffer */ @@ -529,7 +535,7 @@ export default class RealtimeClient { let { topic, event, payload, ref } = msg if (topic === 'phoenix' && event === 'phx_reply') { - this.heartbeatCallback(msg) + this.heartbeatCallback(msg.payload.status == 'ok' ? 'ok' : 'error') } if (ref && ref === this.pendingHeartbeatRef) { diff --git a/test/client.test.ts b/test/client.test.ts index 45a764e..32275a9 100644 --- a/test/client.test.ts +++ b/test/client.test.ts @@ -5,7 +5,10 @@ import WebSocket from 'ws' import sinon from 'sinon' import crypto from 'crypto' -import RealtimeClient, { RealtimeMessage } from '../src/RealtimeClient' +import RealtimeClient, { + HeartbeatStatus, + RealtimeMessage, +} from '../src/RealtimeClient' import jwt from 'jsonwebtoken' import { CHANNEL_STATES, LOG_LEVEL } from '../src/lib/constants' @@ -711,11 +714,8 @@ describe('onConnMessage', () => { test("on heartbeat events from the 'phoenix' topic, callback is called", async () => { let called = false - let socket = new RealtimeClient(url, { - heartbeatCallback: (message: RealtimeMessage) => { - called = message.payload.status == 'ok' - }, - }) + let socket = new RealtimeClient(url) + socket.onHeartbeat((message: HeartbeatStatus) => (called = message == 'ok')) const message = '{"ref":"1","event":"phx_reply","payload":{"status":"ok","response":{}},"topic":"phoenix"}' @@ -724,6 +724,22 @@ describe('onConnMessage', () => { socket.pendingHeartbeatRef = '3' socket._onConnMessage(data) + assert.strictEqual(called, true) + }) + test("on heartbeat events from the 'phoenix' topic, callback is called with error", async () => { + let called = false + let socket = new RealtimeClient(url) + socket.onHeartbeat( + (message: HeartbeatStatus) => (called = message == 'error') + ) + + const message = + '{"ref":"1","event":"phx_reply","payload":{"status":"error","response":{}},"topic":"phoenix"}' + const data = { data: message } + + socket.pendingHeartbeatRef = '3' + socket._onConnMessage(data) + assert.strictEqual(called, true) }) })