diff --git a/js/src/client.ts b/js/src/client.ts
index a1eb830d..adce16b1 100644
--- a/js/src/client.ts
+++ b/js/src/client.ts
@@ -75,6 +75,7 @@ export interface ClientConfig {
   hideOutputs?: boolean | ((outputs: KVMap) => KVMap);
   autoBatchTracing?: boolean;
   batchSizeBytesLimit?: number;
+  tracePayloadByteCompressionLimit?: number;
   blockOnRootRunFinalization?: boolean;
   traceBatchConcurrency?: number;
   fetchOptions?: RequestInit;
@@ -364,6 +365,46 @@ const handle429 = async (response?: Response) => {
   return false;
 };

+const _compressPayload = async (
+  payload: string | Uint8Array,
+  contentType: string
+) => {
+  const compressedPayloadStream = new Blob([payload])
+    .stream()
+    .pipeThrough(new CompressionStream("gzip"));
+  const reader = compressedPayloadStream.getReader();
+  const chunks = [];
+  let totalLength = 0;
+  // eslint-disable-next-line no-constant-condition
+  while (true) {
+    const { done, value } = await reader.read();
+    if (done) break;
+    chunks.push(value);
+    totalLength += value.length;
+  }
+  return new Blob(chunks, {
+    type: `${contentType}; length=${totalLength}; encoding=gzip`,
+  });
+};
+
+const _preparePayload = async (
+  payload: any,
+  contentType: string,
+  compressionThreshold: number
+) => {
+  let finalPayload = payload;
+  // eslint-disable-next-line no-instanceof/no-instanceof
+  if (!(payload instanceof Uint8Array)) {
+    finalPayload = stringifyForTracing(payload);
+  }
+  if (finalPayload.length < compressionThreshold) {
+    return new Blob([finalPayload], {
+      type: `${contentType}; length=${finalPayload.length}`,
+    });
+  }
+  return _compressPayload(finalPayload, contentType);
+};
+
 export class AutoBatchQueue {
   items: {
     action: "create" | "update";
@@ -435,6 +476,8 @@ export class AutoBatchQueue {
 // 20 MB
 export const DEFAULT_BATCH_SIZE_LIMIT_BYTES = 20_971_520;

+const DEFAULT_MAX_UNCOMPRESSED_PAYLOAD_LIMIT = 10 * 1024;
+
 const SERVER_INFO_REQUEST_TIMEOUT = 1000;

 export class Client {
@@ -468,6 +511,9 @@ export class Client {

   private autoBatchAggregationDelayMs = 250;

+  private tracePayloadByteCompressionLimit =
+    DEFAULT_MAX_UNCOMPRESSED_PAYLOAD_LIMIT;
+
   private batchSizeBytesLimit?: number;

   private fetchOptions: RequestInit;
@@ -520,6 +566,9 @@ export class Client {
     this.blockOnRootRunFinalization =
       config.blockOnRootRunFinalization ?? this.blockOnRootRunFinalization;
     this.batchSizeBytesLimit = config.batchSizeBytesLimit;
+    this.tracePayloadByteCompressionLimit =
+      config.tracePayloadByteCompressionLimit ??
+      this.tracePayloadByteCompressionLimit;
     this.fetchOptions = config.fetchOptions || {};
   }

@@ -1038,6 +1087,7 @@ export class Client {
       delete preparedCreate.attachments;
       preparedCreateParams.push(preparedCreate);
     }
+
     let preparedUpdateParams = [];
     for (const update of runUpdates ?? []) {
       preparedUpdateParams.push(this.prepareRunCreateOrUpdateInputs(update));
@@ -1109,24 +1159,26 @@
         originalPayload;
       const fields = { inputs, outputs, events };
       // encode the main run payload
-      const stringifiedPayload = stringifyForTracing(payload);
       accumulatedParts.push({
         name: `${method}.${payload.id}`,
-        payload: new Blob([stringifiedPayload], {
-          type: `application/json; length=${stringifiedPayload.length}`, // encoding=gzip
-        }),
+        payload: await _preparePayload(
+          payload,
+          "application/json",
+          this.tracePayloadByteCompressionLimit
+        ),
       });
       // encode the fields we collected
       for (const [key, value] of Object.entries(fields)) {
         if (value === undefined) {
           continue;
         }
-        const stringifiedValue = stringifyForTracing(value);
         accumulatedParts.push({
           name: `${method}.${payload.id}.${key}`,
-          payload: new Blob([stringifiedValue], {
-            type: `application/json; length=${stringifiedValue.length}`,
-          }),
+          payload: await _preparePayload(
+            value,
+            "application/json",
+            this.tracePayloadByteCompressionLimit
+          ),
         });
       }
       // encode the attachments
@@ -1147,9 +1199,11 @@
         }
         accumulatedParts.push({
           name: `attachment.${payload.id}.${name}`,
-          payload: new Blob([content], {
-            type: `${contentType}; length=${content.byteLength}`,
-          }),
+          payload: await _preparePayload(
+            content,
+            contentType,
+            this.tracePayloadByteCompressionLimit
+          ),
         });
       }
     }
@@ -1170,8 +1224,7 @@
     for (const part of parts) {
       formData.append(part.name, part.payload);
     }
-    // Log the form data
-    await this.batchIngestCaller.call(
+    const res = await this.batchIngestCaller.call(
       _getFetchImplementation(),
       `${this.apiUrl}/runs/multipart`,
       {
@@ -1184,15 +1237,10 @@
           ...this.fetchOptions,
         }
       );
-    } catch (e) {
-      let errorMessage = "Failed to multipart ingest runs";
-      // eslint-disable-next-line no-instanceof/no-instanceof
-      if (e instanceof Error) {
-        errorMessage += `: ${e.stack || e.message}`;
-      } else {
-        errorMessage += `: ${String(e)}`;
-      }
-      console.warn(`${errorMessage.trim()}\n\nContext: ${context}`);
+      await raiseForStatus(res, "ingest multipart runs", true);
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    } catch (e: any) {
+      console.warn(`${e.message.trim()}\n\nContext: ${context}`);
     }
   }

diff --git a/js/src/tests/batch_client.int.test.ts b/js/src/tests/batch_client.int.test.ts
index 5e4126c7..a9fe4cc2 100644
--- a/js/src/tests/batch_client.int.test.ts
+++ b/js/src/tests/batch_client.int.test.ts
@@ -243,6 +243,59 @@ test.concurrent(
   180_000
 );

+test("Test persist run with all items compressed", async () => {
+  const langchainClient = new Client({
+    autoBatchTracing: true,
+    callerOptions: { maxRetries: 2 },
+    timeout_ms: 30_000,
+    tracePayloadByteCompressionLimit: 1,
+  });
+  const projectName = "__test_compression" + uuidv4().substring(0, 4);
+  await deleteProject(langchainClient, projectName);
+
+  const runId = uuidv4();
+  const dottedOrder = convertToDottedOrderFormat(
+    new Date().getTime() / 1000,
+    runId
+  );
+  const pathname = path.join(
+    path.dirname(fileURLToPath(import.meta.url)),
+    "test_data",
+    "parrot-icon.png"
+  );
+  await langchainClient.createRun({
+    id: runId,
+    project_name: projectName,
+    name: "test_run",
+    run_type: "llm",
+    inputs: { text: "hello world" },
+    trace_id: runId,
+    dotted_order: dottedOrder,
+    attachments: {
+      testimage: ["image/png", new Uint8Array(fs.readFileSync(pathname))],
+    },
+  });
+
+  await new Promise((resolve) => setTimeout(resolve, 1000));
+
+  await langchainClient.updateRun(runId, {
+    outputs: { output: ["Hi"] },
+    dotted_order: dottedOrder,
+    trace_id: runId,
+    end_time: Math.floor(new Date().getTime() / 1000),
+  });
+
+  await Promise.all([
+    waitUntilRunFound(langchainClient, runId, true),
+    waitUntilProjectFound(langchainClient, projectName),
+  ]);
+
+  const storedRun = await langchainClient.readRun(runId);
+  expect(storedRun.id).toEqual(runId);
+  expect(storedRun.status).toEqual("success");
+  // await langchainClient.deleteProject({ projectName });
+}, 180_000);
+
 test.skip("very large runs", async () => {
   const langchainClient = new Client({
     autoBatchTracing: true,
diff --git a/js/src/tests/batch_client.test.ts b/js/src/tests/batch_client.test.ts
index 43c85c15..5f7cda7d 100644
--- a/js/src/tests/batch_client.test.ts
+++ b/js/src/tests/batch_client.test.ts
@@ -2,6 +2,7 @@
 /* eslint-disable prefer-const */
 import { jest } from "@jest/globals";
 import { v4 as uuidv4 } from "uuid";
+import * as zlib from "node:zlib";
 import { Client, mergeRuntimeEnvIntoRunCreate } from "../client.js";
 import { convertToDottedOrderFormat } from "../run_trees.js";
 import { _getFetchImplementation } from "../singletons/fetch.js";
@@ -24,7 +25,17 @@ const parseMockRequestBody = async (body: string | FormData) => {
     try {
       parsedValue = JSON.parse(text);
     } catch (e) {
-      parsedValue = text;
+      try {
+        // Try decompression
+        const decompressed = zlib
+          .gunzipSync(Buffer.from(await value.arrayBuffer()))
+          .toString();
+        parsedValue = JSON.parse(decompressed);
+      } catch (e) {
+        console.log(e);
+        // Give up
+        parsedValue = text;
+      }
     }
     // if (method === "attachment") {
     //   for (const item of reconstructedBody.post) {
@@ -609,12 +620,20 @@ describe.each(ENDPOINT_TYPES)(

       await new Promise((resolve) => setTimeout(resolve, 10));

-      const calledRequestParam: any = callSpy.mock.calls[0][2];
-      const calledRequestParam2: any = callSpy.mock.calls[1][2];
+      const calledRequestBody = await parseMockRequestBody(
+        (callSpy.mock.calls[0][2] as any).body
+      );
+      const calledRequestBody2: any = await parseMockRequestBody(
+        (callSpy.mock.calls[1][2] as any).body
+      );
       // Queue should drain as soon as size limit is reached,
       // sending both batches
-      expect(await parseMockRequestBody(calledRequestParam?.body)).toEqual({
+      expect(
+        calledRequestBody.post.length === 10
+          ? calledRequestBody
+          : calledRequestBody2
+      ).toEqual({
         post: runIds.slice(0, 10).map((runId, i) =>
           expect.objectContaining({
             id: runId,
           })
@@ -628,7 +647,11 @@
         ),
         patch: [],
       });
-      expect(await parseMockRequestBody(calledRequestParam2?.body)).toEqual({
+      expect(
+        calledRequestBody.post.length === 5
+          ? calledRequestBody
+          : calledRequestBody2
+      ).toEqual({
         post: runIds.slice(10).map((runId, i) =>
           expect.objectContaining({
             id: runId,
@@ -784,7 +807,7 @@
         dotted_order: dottedOrder,
       });

-      await new Promise((resolve) => setTimeout(resolve, 300));
+      await client.awaitPendingTraceBatches();

       const calledRequestParam: any = callSpy.mock.calls[0][2];
       expect(
@@ -903,3 +926,82 @@ describe.each(ENDPOINT_TYPES)(
     });
   }
 );
+
+it("should compress fields above the compression limit", async () => {
+  const client = new Client({
+    apiKey: "test-api-key",
+    tracePayloadByteCompressionLimit: 1000,
+    autoBatchTracing: true,
+  });
+  const callSpy = jest
+    .spyOn((client as any).batchIngestCaller, "call")
+    .mockResolvedValue({
+      ok: true,
+      text: () => "",
+    });
+  jest.spyOn(client as any, "_getServerInfo").mockImplementation(() => {
+    return {
+      version: "foo",
+      batch_ingest_config: { use_multipart_endpoint: true },
+    };
+  });
+
+  const projectName = "__test_compression";
+
+  const runId = uuidv4();
+  const dottedOrder = convertToDottedOrderFormat(
+    new Date().getTime() / 1000,
+    runId
+  );
+
+  await client.createRun({
+    id: runId,
+    project_name: projectName,
+    name: "test_run",
+    run_type: "llm",
+    inputs: { text: "hello world!" },
+    trace_id: runId,
+    dotted_order: dottedOrder,
+  });
+
+  const runId2 = uuidv4();
+  const dottedOrder2 = convertToDottedOrderFormat(
+    new Date().getTime() / 1000,
+    runId2
+  );
+
+  await client.createRun({
+    id: runId2,
+    project_name: projectName,
+    name: "test_run2",
+    run_type: "llm",
+    inputs: { text: `hello world!${"x".repeat(1000)}` },
+    trace_id: runId2,
+    dotted_order: dottedOrder2,
+  });
+
+  await client.awaitPendingTraceBatches();
+
+  const calledRequestParam: any = callSpy.mock.calls[0][2];
+  expect(await parseMockRequestBody(calledRequestParam?.body)).toEqual({
+    post: [
+      expect.objectContaining({
+        id: runId,
+        run_type: "llm",
+        inputs: {
+          text: "hello world!",
+        },
+        trace_id: runId,
+      }),
+      expect.objectContaining({
+        id: runId2,
+        run_type: "llm",
+        inputs: {
+          text: `hello world!${"x".repeat(1000)}`,
+        },
+        trace_id: runId2,
+      }),
+    ],
+    patch: [],
+  });
+});
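
Reviewer note: below is a minimal usage sketch of the new `tracePayloadByteCompressionLimit` option, not part of the change itself. The `langsmith` import path, the project name, the threshold value, and the demo payload are assumptions for illustration; `createRun` and `awaitPendingTraceBatches` are the same public `Client` methods exercised by the tests above.

```ts
// Sketch only: exercising tracePayloadByteCompressionLimit from user code.
// Assumptions: package imported as "langsmith"; key/project values are placeholders.
import { v4 as uuidv4 } from "uuid";
import { Client } from "langsmith";

const client = new Client({
  apiKey: process.env.LANGSMITH_API_KEY,
  autoBatchTracing: true,
  // Serialized run parts at or above this many bytes are gzipped by
  // _preparePayload before being appended to the multipart form body
  // (content type gains `encoding=gzip`); smaller parts are sent as plain
  // `application/json; length=N`. Defaults to 10 * 1024 per this diff.
  tracePayloadByteCompressionLimit: 1024,
});

async function main() {
  const runId = uuidv4();
  await client.createRun({
    id: runId,
    name: "compression-demo",
    run_type: "llm",
    project_name: "my-project",
    // Large enough to cross the 1 KB threshold and trigger compression.
    inputs: { text: "x".repeat(4096) },
  });
  // Flush pending batches so the multipart request is actually sent.
  await client.awaitPendingTraceBatches();
}

main().catch(console.error);
```

One design point worth noting from the diff: compression is decided per multipart part (run body, each field, each attachment), so a run with one oversized field still sends its small fields uncompressed.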