From 12d2b8ea2b5061143b528c6b9fa111efaaf37d33 Mon Sep 17 00:00:00 2001 From: Oliver Klemenz Date: Thu, 17 Oct 2024 16:29:55 +0200 Subject: [PATCH] Fix support for absolute service paths --- CHANGELOG.md | 1 + src/adapter/redis.js | 8 +- src/index.js | 104 +++++++++++----------- src/socket/base.js | 9 ++ src/socket/socket.io.js | 6 +- src/socket/ws.js | 6 +- test/socketio/facade_socket.io.test.js | 2 +- test/socketio/protocols_socket.io.test.js | 3 +- test/ws/facade_ws.test.js | 2 +- test/ws/protocols_ws.test.js | 3 +- 10 files changed, 77 insertions(+), 67 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ba8990e..5cd3174 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Support for http conform headers (`x-ws` and `x-websocket`) - Revise error handling for websocket events - Fix for operations without parameters +- Fix support for absolute service paths ## Version 1.3.0 - 2024-10-07 diff --git a/src/adapter/redis.js b/src/adapter/redis.js index 348a305..916ad50 100644 --- a/src/adapter/redis.js +++ b/src/adapter/redis.js @@ -23,7 +23,7 @@ class RedisAdapter extends BaseAdapter { return; } try { - const channel = this.prefix + path; + const channel = this.getChannel(path); await this.client.subscribe(channel, async (message, messageChannel) => { try { if (messageChannel === channel) { @@ -43,12 +43,16 @@ class RedisAdapter extends BaseAdapter { return; } try { - const channel = this.prefix + path; + const channel = this.getChannel(path); await this.client.publish(channel, message); } catch (err) { LOG?.error(err); } } + + getChannel(path) { + return `${this.prefix}/${path}`; + } } module.exports = RedisAdapter; diff --git a/src/index.js b/src/index.js index 348505f..1e4d92c 100644 --- a/src/index.js +++ b/src/index.js @@ -43,49 +43,46 @@ function serveWebSocketServer(options) { } } // Websockets events - if (cds.env.protocols.websocket || cds.env.protocols.ws) { - const eventServices = {}; - for (const name in cds.model.definitions) { - const definition = cds.model.definitions[name]; - if (definition.kind === "event" && (definition["@websocket"] || definition["@ws"])) { - const service = cds.services[definition._service?.name]; - if (service && !isServedViaWebsocket(service)) { - eventServices[service.name] ??= eventServices[service.name] || { - name: service.name, - definition: service.definition, - endpoints: service.endpoints.map((endpoint) => { - const protocol = - cds.env.protocols[endpoint.kind] || - (endpoint.kind === "odata" ? cds.env.protocols["odata-v4"] : null); - return { - kind: "websocket", - path: - (cds.env.protocols.websocket?.path || cds.env.protocols.ws?.path) + - normalizeServicePath(service.path, protocol.path), - }; - }), - operations: () => { - return interableObject(); - }, - entities: () => { - return interableObject(); - }, - _events: interableObject(), - events: function () { - return this._events; - }, - on: service.on.bind(service), - tx: service.tx.bind(service), - }; - eventServices[service.name]._events[serviceLocalName(service, definition.name)] = definition; - } + const eventServices = {}; + for (const name in cds.model.definitions) { + const definition = cds.model.definitions[name]; + if (definition.kind === "event" && (definition["@websocket"] || definition["@ws"])) { + const service = cds.services[definition._service?.name]; + if (service && !isServedViaWebsocket(service)) { + eventServices[service.name] ??= eventServices[service.name] || { + name: service.name, + definition: service.definition, + endpoints: service.endpoints.map((endpoint) => { + const protocol = + cds.env.protocols[endpoint.kind] || + (endpoint.kind === "odata" ? cds.env.protocols["odata-v4"] : null); + let path = normalizeServicePath(service.path, protocol.path); + if (!path.startsWith("/")) { + path = (cds.env.protocols?.websocket?.path || cds.env.protocols?.ws?.path || "/ws") + "/" + path; + } + return { kind: "websocket", path }; + }), + operations: () => { + return interableObject(); + }, + entities: () => { + return interableObject(); + }, + _events: interableObject(), + events: function () { + return this._events; + }, + on: service.on.bind(service), + tx: service.tx.bind(service), + }; + eventServices[service.name]._events[serviceLocalName(service, definition.name)] = definition; } } - for (const name in eventServices) { - const eventService = eventServices[name]; - if (Object.keys(eventService.events()).length > 0) { - serveWebSocketService(socketServer, eventService, options); - } + } + for (const name in eventServices) { + const eventService = eventServices[name]; + if (Object.keys(eventService.events()).length > 0) { + serveWebSocketService(socketServer, eventService, options); } } LOG?.info("using websocket", { kind: cds.env.websocket.kind, adapter: socketServer.adapterActive }); @@ -112,8 +109,8 @@ async function initWebSocketServer(server, path) { } function normalizeServicePath(servicePath, protocolPath) { - if (servicePath.startsWith(protocolPath)) { - return servicePath.substring(protocolPath.length); + if (servicePath.startsWith(`${protocolPath}/`)) { + return servicePath.substring(`${protocolPath}/`.length); } return servicePath; } @@ -152,11 +149,11 @@ function bindServiceEvents(socketServer, service, path) { const user = deriveUser(event, req.data, headers, req); const context = deriveContext(event, req.data, headers); const identifier = deriveIdentifier(event, req.data, headers); - const eventHeaders = headers?.websocket || headers?.ws ? { ...headers?.websocket, ...headers?.ws } : undefined; - path = normalizeEventPath(event["@websocket.path"] || event["@ws.path"] || path); + const eventHeaders = deriveEventHeaders(headers); + const eventPath = derivePath(event, path); await socketServer.broadcast({ service, - path, + path: eventPath, event: localEventName, data: req.data, tenant: req.tenant, @@ -702,6 +699,14 @@ function deriveHeaders(headers, format) { return headers; } +function deriveEventHeaders(headers) { + return headers?.websocket || headers?.ws ? { ...headers?.websocket, ...headers?.ws } : undefined; +} + +function derivePath(event, path) { + return event["@websocket.path"] || event["@ws.path"] || path; +} + function getDeepEntityColumns(entity) { const columns = []; for (const element of Object.values(entity.elements)) { @@ -719,13 +724,6 @@ function getDeepEntityColumns(entity) { return columns; } -function normalizeEventPath(path) { - if (!path) { - return path; - } - return path.startsWith("/") ? path : `/${path}`; -} - function serviceLocalName(service, name) { const servicePrefix = `${service.name}.`; if (name.startsWith(servicePrefix)) { diff --git a/src/socket/base.js b/src/socket/base.js index 6836d38..7b21acc 100644 --- a/src/socket/base.js +++ b/src/socket/base.js @@ -370,6 +370,15 @@ class SocketServer { return require(impl); } + /** + * Return service path including protocol prefix or absolute service path (if already absolute) + * @param {String} path path + * @returns {String} Service path + */ + servicePath(path) { + return path.startsWith("/") ? path : `${this.path}/${path}`; + } + /** * Return format instance for service * @param {Object} service Service definition diff --git a/src/socket/socket.io.js b/src/socket/socket.io.js index b79f57b..00a9f41 100644 --- a/src/socket/socket.io.js +++ b/src/socket/socket.io.js @@ -26,8 +26,7 @@ class SocketIOServer extends SocketServer { } service(service, path, connected) { - const servicePath = `${this.path}${path}`; - const io = this.applyMiddlewares(this.io.of(servicePath)); + const io = this.applyMiddlewares(this.io.of(this.servicePath(path))); const format = this.format(service, undefined, "json"); io.on("connection", async (socket) => { try { @@ -175,8 +174,7 @@ class SocketIOServer extends SocketServer { try { path = path || this.defaultPath(service); tenant = tenant || socket?.context.tenant; - const servicePath = `${this.path}${path}`; - let to = socket?.broadcast || this.io.of(servicePath); + let to = socket?.broadcast || this.io.of(this.servicePath(path)); if (context?.include?.length && identifier?.include?.length) { for (const contextInclude of context.include) { for (const identifierInclude of identifier.include) { diff --git a/src/socket/ws.js b/src/socket/ws.js index 25311f0..71e3077 100644 --- a/src/socket/ws.js +++ b/src/socket/ws.js @@ -38,9 +38,8 @@ class SocketWSServer extends SocketServer { service(service, path, connected) { this.adapter?.on(service, path); - const servicePath = `${this.path}${path}`; const format = this.format(service); - this.services[servicePath] = (ws, request) => { + this.services[this.servicePath(path)] = (ws, request) => { this.onInit(ws, request); DEBUG?.("Initialized"); ws.on("close", () => { @@ -154,8 +153,7 @@ class SocketWSServer extends SocketServer { } path = path || this.defaultPath(service); tenant = tenant || socket?.context.tenant; - const servicePath = `${this.path}${path}`; - const serviceClients = this.fetchClients(tenant, servicePath); + const serviceClients = this.fetchClients(tenant, this.servicePath(path)); const clients = new Set(serviceClients.all); if (user?.include?.length) { this.keepEntriesFromSet(clients, this.collectFromMap(serviceClients.users, user?.include)); diff --git a/test/socketio/facade_socket.io.test.js b/test/socketio/facade_socket.io.test.js index f6a9757..becb59e 100644 --- a/test/socketio/facade_socket.io.test.js +++ b/test/socketio/facade_socket.io.test.js @@ -26,7 +26,7 @@ describe("Facade", () => { const facade = socket.serverSocket.facade; expect(facade).toBeDefined(); expect(facade.service).toEqual(expect.any(Object)); - expect(facade.path).toEqual("/chat"); + expect(facade.path).toEqual("chat"); expect(facade.socket).toBeDefined(); const context = facade.context; expect(context).toBeDefined(); diff --git a/test/socketio/protocols_socket.io.test.js b/test/socketio/protocols_socket.io.test.js index f86331e..35fd0ca 100644 --- a/test/socketio/protocols_socket.io.test.js +++ b/test/socketio/protocols_socket.io.test.js @@ -25,7 +25,8 @@ describe("Protocols", () => { }); test.each(protocols)("Protocol - %p", async (protocol) => { - const socket = await connect(`/ws/${protocol}`); + const path = (protocol.includes("absolute") ? "/" : "/ws/") + protocol; + const socket = await connect(path); const waitProtocol = waitForEvent(socket, "test"); await emitEvent(socket, "trigger", { text: protocol }); const waitResult = await waitProtocol; diff --git a/test/ws/facade_ws.test.js b/test/ws/facade_ws.test.js index 1da4561..f3dd867 100644 --- a/test/ws/facade_ws.test.js +++ b/test/ws/facade_ws.test.js @@ -24,7 +24,7 @@ describe("Facade", () => { const facade = socket.serverSocket.facade; expect(facade).toBeDefined(); expect(facade.service).toEqual(expect.any(Object)); - expect(facade.path).toEqual("/chat"); + expect(facade.path).toEqual("chat"); expect(facade.socket).toBeDefined(); const context = facade.context; expect(context).toBeDefined(); diff --git a/test/ws/protocols_ws.test.js b/test/ws/protocols_ws.test.js index 08dda83..182b07a 100644 --- a/test/ws/protocols_ws.test.js +++ b/test/ws/protocols_ws.test.js @@ -23,7 +23,8 @@ describe("Protocols", () => { }); test.each(protocols)("Protocol - %p", async (protocol) => { - const socket = await connect("/ws/" + protocol); + const path = (protocol.includes("absolute") ? "/" : "/ws/") + protocol; + const socket = await connect(path); const waitProtocol = waitForEvent(socket, "test"); await emitEvent(socket, "trigger", { text: protocol }); const waitResult = await waitProtocol;