From 4ce8069850d6560d2ed0b0895fe66306c432bf9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Halvard=20M=C3=B8rstad?= Date: Mon, 15 Apr 2024 11:48:12 +0200 Subject: [PATCH] Refactored client and improved connection --- lib/client.test.ts | 50 +-- lib/client.ts | 522 +++++++++++++++++++++--------- lib/connection.test.ts | 252 ++++++++++++--- lib/connection.ts | 499 +++++++++++++++++++--------- lib/connection2.ts | 590 ---------------------------------- lib/packets/builders/query.ts | 6 +- lib/packets/parsers/err.ts | 4 +- lib/packets/parsers/result.ts | 22 +- lib/utils/logger.ts | 19 +- lib/utils/query.ts | 15 +- 10 files changed, 1005 insertions(+), 974 deletions(-) delete mode 100644 lib/connection2.ts diff --git a/lib/client.test.ts b/lib/client.test.ts index f861732..2508467 100644 --- a/lib/client.test.ts +++ b/lib/client.test.ts @@ -1,25 +1,27 @@ -// import { MysqlClient } from "./client.ts"; -// import { URL_TEST_CONNECTION } from "./utils/testing.ts"; -// import { implementationTest } from "@halvardm/sqlx/testing"; +import { MysqlClient } from "./client.ts"; +import { URL_TEST_CONNECTION } from "./utils/testing.ts"; +import { implementationTest } from "@halvardm/sqlx/testing"; -// Deno.test("MysqlClient", async (t) => { -// await implementationTest({ -// t, -// Client: MysqlClient, -// connectionUrl: URL_TEST_CONNECTION, -// connectionOptions: {}, -// queries:{ -// createTable: "CREATE TABLE IF NOT EXISTS sqlxtesttable (testcol TEXT)", -// dropTable: "DROP TABLE IF EXISTS sqlxtesttable", -// insertOneToTable: "INSERT INTO sqlxtesttable (testcol) VALUES (?)", -// insertManyToTable: "INSERT INTO sqlxtesttable (testcol) VALUES (?),(?),(?)", -// selectOneFromTable: "SELECT * FROM sqlxtesttable WHERE testcol = ? LIMIT 1", -// selectByMatchFromTable: "SELECT * FROM sqlxtesttable WHERE testcol = ?", -// selectManyFromTable: "SELECT * FROM sqlxtesttable", -// select1AsString: "SELECT '1' as result", -// select1Plus1AsNumber: "SELECT 1+1 as result", -// deleteByMatchFromTable: "DELETE FROM sqlxtesttable WHERE testcol = ?", -// deleteAllFromTable: "DELETE FROM sqlxtesttable", -// } -// }); -// }); +Deno.test("MysqlClient", async (t) => { + await implementationTest({ + t, + Client: MysqlClient, + connectionUrl: URL_TEST_CONNECTION, + connectionOptions: {}, + queries: { + createTable: "CREATE TABLE IF NOT EXISTS sqlxtesttable (testcol TEXT)", + dropTable: "DROP TABLE IF EXISTS sqlxtesttable", + insertOneToTable: "INSERT INTO sqlxtesttable (testcol) VALUES (?)", + insertManyToTable: + "INSERT INTO sqlxtesttable (testcol) VALUES (?),(?),(?)", + selectOneFromTable: + "SELECT * FROM sqlxtesttable WHERE testcol = ? LIMIT 1", + selectByMatchFromTable: "SELECT * FROM sqlxtesttable WHERE testcol = ?", + selectManyFromTable: "SELECT * FROM sqlxtesttable", + select1AsString: "SELECT '1' as result", + select1Plus1AsNumber: "SELECT 1+1 as result", + deleteByMatchFromTable: "DELETE FROM sqlxtesttable WHERE testcol = ?", + deleteAllFromTable: "DELETE FROM sqlxtesttable", + }, + }); +}); diff --git a/lib/client.ts b/lib/client.ts index a6de1f4..6779730 100644 --- a/lib/client.ts +++ b/lib/client.ts @@ -1,170 +1,392 @@ import { - type Connection, - ConnectionState, - type ExecuteResult, -} from "./connection.ts"; -import { ConnectionPool, PoolConnection } from "./pool.ts"; -import { logger } from "./utils/logger.ts"; -import { MysqlError } from "./utils/errors.ts"; + type ArrayRow, + type Row, + type SqlxConnection, + SqlxConnectionCloseEvent, + SqlxConnectionConnectEvent, + type SqlxConnectionEventType, + type SqlxPreparable, + type SqlxPreparedQueriable, + type SqlxQueriable, + type SqlxQueryOptions, + type SqlxTransactionable, + type SqlxTransactionOptions, + type SqlxTransactionQueriable, + VERSION, +} from "@halvardm/sqlx"; +import { MysqlConnection, type MysqlConnectionOptions } from "./connection.ts"; +import { buildQuery } from "./packets/builders/query.ts"; +import { + getRowObject, + type MysqlParameterType, +} from "./packets/parsers/result.ts"; -/** - * Client Config - */ -export interface ClientConfig { - /** Database hostname */ - hostname?: string; - /** Database UNIX domain socket path. When used, `hostname` and `port` are ignored. */ - socketPath?: string; - /** Database username */ - username?: string; - /** Database password */ - password?: string; - /** Database port */ - port?: number; - /** Database name */ - db?: string; - /** Whether to display packet debugging information */ - debug?: boolean; - /** Connection read timeout (default: 30 seconds) */ - timeout?: number; - /** Connection pool size (default: 1) */ - poolSize?: number; - /** Connection pool idle timeout in microseconds (default: 4 hours) */ - idleTimeout?: number; - /** charset */ - charset?: string; - /** tls config */ - tls?: TLSConfig; +export interface MysqlTransactionOptions extends SqlxTransactionOptions { + beginTransactionOptions: { + withConsistentSnapshot?: boolean; + readWrite?: "READ WRITE" | "READ ONLY"; + }; + commitTransactionOptions: { + chain?: boolean; + release?: boolean; + }; + rollbackTransactionOptions: { + chain?: boolean; + release?: boolean; + savepoint?: string; + }; +} + +export interface MysqlClientOptions extends MysqlConnectionOptions { } -export enum TLSMode { - DISABLED = "disabled", - VERIFY_IDENTITY = "verify_identity", +export interface MysqlQueryOptions extends SqlxQueryOptions { } + /** - * TLS Config + * Prepared statement + * + * @todo implement prepared statements properly */ -export interface TLSConfig { - /** mode of tls. only support disabled and verify_identity now*/ - mode?: TLSMode; - /** A list of root certificates (must be PEM format) that will be used in addition to the - * default root certificates to verify the peer's certificate. */ - caCerts?: string[]; +export class MysqlPrepared + implements SqlxPreparedQueriable { + readonly sqlxVersion = VERSION; + readonly queryOptions: MysqlQueryOptions; + + #sql: string; + + #queriable: MysqlQueriable; + + constructor( + connection: MysqlConnection, + sql: string, + options: MysqlQueryOptions = {}, + ) { + this.#queriable = new MysqlQueriable(connection); + this.#sql = sql; + this.queryOptions = options; + } + + execute( + params?: MysqlParameterType[] | undefined, + _options?: MysqlQueryOptions | undefined, + ): Promise { + return this.#queriable.execute(this.#sql, params); + } + query = Row>( + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): Promise { + return this.#queriable.query(this.#sql, params, options); + } + queryOne = Row>( + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): Promise { + return this.#queriable.queryOne(this.#sql, params, options); + } + queryMany = Row>( + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): AsyncIterableIterator { + return this.#queriable.queryMany(this.#sql, params, options); + } + queryArray< + T extends ArrayRow = ArrayRow, + >( + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): Promise { + return this.#queriable.queryArray(this.#sql, params, options); + } + queryOneArray< + T extends ArrayRow = ArrayRow, + >( + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): Promise { + return this.#queriable.queryOneArray(this.#sql, params, options); + } + queryManyArray< + T extends ArrayRow = ArrayRow, + >( + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): AsyncIterableIterator { + return this.#queriable.queryManyArray(this.#sql, params, options); + } +} + +export class MysqlQueriable + implements SqlxQueriable { + protected readonly connection: MysqlConnection; + readonly queryOptions: MysqlQueryOptions; + readonly sqlxVersion: string = VERSION; + + constructor( + connection: MysqlConnection, + queryOptions: MysqlQueryOptions = {}, + ) { + this.connection = connection; + this.queryOptions = queryOptions; + } + + execute( + sql: string, + params?: MysqlParameterType[] | undefined, + _options?: MysqlQueryOptions | undefined, + ): Promise { + const data = buildQuery(sql, params); + return this.connection.executeRaw(data); + } + query = Row>( + sql: string, + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): Promise { + return Array.fromAsync(this.queryMany(sql, params, options)); + } + async queryOne = Row>( + sql: string, + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): Promise { + const res = await this.query(sql, params, options); + return res[0]; + } + async *queryMany = Row>( + sql: string, + params?: MysqlParameterType[], + options?: MysqlQueryOptions | undefined, + ): AsyncGenerator { + const data = buildQuery(sql, params); + for await ( + const res of this.connection.queryManyObjectRaw(data, options) + ) { + yield res; + } + } + + queryArray< + T extends ArrayRow = ArrayRow, + >( + sql: string, + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): Promise { + return Array.fromAsync(this.queryManyArray(sql, params, options)); + } + async queryOneArray< + T extends ArrayRow = ArrayRow, + >( + sql: string, + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): Promise { + const res = await this.queryArray(sql, params, options); + return res[0]; + } + async *queryManyArray< + T extends ArrayRow = ArrayRow, + >( + sql: string, + params?: MysqlParameterType[] | undefined, + options?: MysqlQueryOptions | undefined, + ): AsyncIterableIterator { + const data = buildQuery(sql, params); + for await ( + const res of this.connection.queryManyArrayRaw(data, options) + ) { + yield res; + } + } + sql = Row>( + strings: TemplateStringsArray, + ...parameters: MysqlParameterType[] + ): Promise { + return this.query(strings.join("?"), parameters); + } + sqlArray< + T extends ArrayRow = ArrayRow, + >( + strings: TemplateStringsArray, + ...parameters: MysqlParameterType[] + ): Promise { + return this.queryArray(strings.join("?"), parameters); + } +} + +export class MysqlPreparable extends MysqlQueriable + implements + SqlxPreparable { + prepare(sql: string, options?: MysqlQueryOptions | undefined): MysqlPrepared { + return new MysqlPrepared(this.connection, sql, options); + } } -/** Transaction processor */ -export interface TransactionProcessor { - (connection: Connection): Promise; +export class MySqlTransaction extends MysqlPreparable + implements + SqlxTransactionQueriable< + MysqlParameterType, + MysqlQueryOptions, + MysqlTransactionOptions + > { + async commitTransaction( + options?: MysqlTransactionOptions["commitTransactionOptions"], + ): Promise { + let sql = "COMMIT"; + + if (options?.chain === true) { + sql += " AND CHAIN"; + } else if (options?.chain === false) { + sql += " AND NO CHAIN"; + } + + if (options?.release === true) { + sql += " RELEASE"; + } else if (options?.release === false) { + sql += " NO RELEASE"; + } + await this.execute(sql); + } + async rollbackTransaction( + options?: MysqlTransactionOptions["rollbackTransactionOptions"], + ): Promise { + let sql = "ROLLBACK"; + + if (options?.savepoint) { + sql += ` TO ${options.savepoint}`; + await this.execute(sql); + return; + } + + if (options?.chain === true) { + sql += " AND CHAIN"; + } else if (options?.chain === false) { + sql += " AND NO CHAIN"; + } + + if (options?.release === true) { + sql += " RELEASE"; + } else if (options?.release === false) { + sql += " NO RELEASE"; + } + + await this.execute(sql); + } + async createSavepoint(name: string = `\t_bm.\t`): Promise { + await this.execute(`SAVEPOINT ${name}`); + } + async releaseSavepoint(name: string = `\t_bm.\t`): Promise { + await this.execute(`RELEASE SAVEPOINT ${name}`); + } } /** - * MySQL client + * Represents a queriable class that can be used to run transactions. */ -export class Client { - config: ClientConfig = {}; - private _pool?: ConnectionPool; - - private async createConnection(): Promise { - let connection = new PoolConnection(this.config); - await connection.connect(); - return connection; - } - - /** get pool info */ - get pool() { - return this._pool?.info; - } - - /** - * connect to database - * @param config config for client - * @returns Client instance - */ - async connect(config: ClientConfig): Promise { - this.config = { - hostname: "127.0.0.1", - username: "root", - port: 3306, - poolSize: 1, - timeout: 30 * 1000, - idleTimeout: 4 * 3600 * 1000, - ...config, - }; - Object.freeze(this.config); - this._pool = new ConnectionPool( - this.config.poolSize || 10, - this.createConnection.bind(this), - ); - return this; - } - - /** - * execute query sql - * @param sql query sql string - * @param params query params - */ - async query(sql: string, params?: any[]): Promise { - return await this.useConnection(async (connection) => { - return await connection.query(sql, params); - }); - } - - /** - * execute sql - * @param sql sql string - * @param params query params - */ - async execute(sql: string, params?: any[]): Promise { - return await this.useConnection(async (connection) => { - return await connection.execute(sql, params); - }); - } - - async useConnection(fn: (conn: Connection) => Promise) { - if (!this._pool) { - throw new MysqlError("Unconnected"); +export class MysqlTransactionable extends MysqlPreparable + implements + SqlxTransactionable< + MysqlParameterType, + MysqlQueryOptions, + MysqlTransactionOptions, + MySqlTransaction + > { + async beginTransaction( + options?: MysqlTransactionOptions["beginTransactionOptions"], + ): Promise { + let sql = "START TRANSACTION"; + if (options?.withConsistentSnapshot) { + sql += ` WITH CONSISTENT SNAPSHOT`; } - const connection = await this._pool.pop(); - try { - return await fn(connection); - } finally { - if (connection.state == ConnectionState.CLOSED) { - connection.removeFromPool(); - } else { - connection.returnToPool(); - } + + if (options?.readWrite) { + sql += ` ${options.readWrite}`; } + + await this.execute(sql); + + return new MySqlTransaction(this.connection, this.queryOptions); } - /** - * Execute a transaction process, and the transaction successfully - * returns the return value of the transaction process - * @param processor transation processor - */ - async transaction(processor: TransactionProcessor): Promise { - return await this.useConnection(async (connection) => { - try { - await connection.execute("BEGIN"); - const result = await processor(connection); - await connection.execute("COMMIT"); - return result; - } catch (error) { - if (connection.state == ConnectionState.CONNECTED) { - logger().info(`ROLLBACK: ${error.message}`); - await connection.execute("ROLLBACK"); - } - throw error; - } - }); - } - - /** - * close connection - */ - async close() { - if (this._pool) { - this._pool.close(); - this._pool = undefined; + async transaction( + fn: (t: MySqlTransaction) => Promise, + options?: MysqlTransactionOptions, + ): Promise { + const transaction = await this.beginTransaction( + options?.beginTransactionOptions, + ); + + try { + const result = await fn(transaction); + await transaction.commitTransaction(options?.commitTransactionOptions); + return result; + } catch (error) { + await transaction.rollbackTransaction( + options?.rollbackTransactionOptions, + ); + throw error; } } } + +/** + * MySQL client + */ +export class MysqlClient extends MysqlTransactionable implements + SqlxConnection< + MysqlParameterType, + MysqlQueryOptions, + MysqlPrepared, + MysqlTransactionOptions, + MySqlTransaction, + SqlxConnectionEventType, + MysqlConnectionOptions + > { + readonly connectionUrl: string; + readonly connectionOptions: MysqlConnectionOptions; + readonly eventTarget: EventTarget; + get connected(): boolean { + throw new Error("Method not implemented."); + } + + constructor( + connectionUrl: string | URL, + connectionOptions: MysqlClientOptions = {}, + ) { + const conn = new MysqlConnection(connectionUrl, connectionOptions); + super(conn); + this.connectionUrl = conn.connectionUrl; + this.connectionOptions = conn.connectionOptions; + this.eventTarget = new EventTarget(); + } + async connect(): Promise { + await this.connection.connect(); + this.dispatchEvent(new SqlxConnectionConnectEvent()); + } + async close(): Promise { + this.dispatchEvent(new SqlxConnectionCloseEvent()); + await this.connection.close(); + } + async [Symbol.asyncDispose](): Promise { + await this.close(); + } + addEventListener( + type: SqlxConnectionEventType, + listener: EventListenerOrEventListenerObject | null, + options?: boolean | AddEventListenerOptions, + ): void { + this.eventTarget.addEventListener(type, listener, options); + } + removeEventListener( + type: SqlxConnectionEventType, + callback: EventListenerOrEventListenerObject | null, + options?: boolean | EventListenerOptions, + ): void { + this.eventTarget.removeEventListener(type, callback, options); + } + dispatchEvent(event: Event): boolean { + return this.eventTarget.dispatchEvent(event); + } +} diff --git a/lib/connection.test.ts b/lib/connection.test.ts index be25f0a..bf15dc9 100644 --- a/lib/connection.test.ts +++ b/lib/connection.test.ts @@ -1,9 +1,10 @@ import { assertEquals, assertInstanceOf } from "@std/assert"; import { emptyDir } from "@std/fs"; import { join } from "@std/path"; -import { MysqlConnection } from "./connection2.ts"; +import { MysqlConnection } from "./connection.ts"; import { DIR_TMP_TEST } from "./utils/testing.ts"; import { buildQuery } from "./packets/builders/query.ts"; +import { URL_TEST_CONNECTION } from "./utils/testing.ts"; Deno.test("Connection", async (t) => { await emptyDir(DIR_TMP_TEST); @@ -19,10 +20,10 @@ Deno.test("Connection", async (t) => { await Deno.writeTextFile(PATH_PEM_KEY, "key"); await t.step("can construct", async (t) => { - const connection = new MysqlConnection("mysql://127.0.0.1:3306"); + const connection = new MysqlConnection(URL_TEST_CONNECTION); assertInstanceOf(connection, MysqlConnection); - assertEquals(connection.connectionUrl, "mysql://127.0.0.1:3306"); + assertEquals(connection.connectionUrl, URL_TEST_CONNECTION); await t.step("can parse connection config simple", () => { const url = new URL("mysql://user:pass@127.0.0.1:3306/db"); @@ -126,7 +127,7 @@ Deno.test("Connection", async (t) => { }); }); - const connection = new MysqlConnection("mysql://root@0.0.0.0:3306"); + const connection = new MysqlConnection(URL_TEST_CONNECTION); assertEquals(connection.connected, false); await t.step("can connect and close", async () => { @@ -144,55 +145,218 @@ Deno.test("Connection", async (t) => { }); await t.step("can connect with using and dispose", async () => { - await using connection = new MysqlConnection("mysql://root@0.0.0.0:3306"); + await using connection = new MysqlConnection(URL_TEST_CONNECTION); assertEquals(connection.connected, false); await connection.connect(); assertEquals(connection.connected, true); }); - await t.step("can execute", async (t) => { - await using connection = new MysqlConnection("mysql://root@0.0.0.0:3306"); - await connection.connect(); - const data = buildQuery("SELECT 1+1 AS result"); - const result = await connection.execute(data); - assertEquals(result, { affectedRows: 0, lastInsertId: null }); - }); + // await t.step("can execute", async (t) => { + // await using connection = new MysqlConnection(URL_TEST_CONNECTION); + // await connection.connect(); + // const data = buildQuery("SELECT 1+1 AS result"); + // const result = await connection.execute(data); + // assertEquals(result, { affectedRows: 0, lastInsertId: null }); + // }); - await t.step("can execute twice", async (t) => { - await using connection = new MysqlConnection("mysql://root@0.0.0.0:3306"); - await connection.connect(); - const data = buildQuery("SELECT 1+1 AS result;"); - const result1 = await connection.execute(data); - assertEquals(result1, { affectedRows: 0, lastInsertId: null }); - const result2 = await connection.execute(data); - assertEquals(result2, { affectedRows: 0, lastInsertId: null }); - }); + // await t.step("can execute twice", async (t) => { + // await using connection = new MysqlConnection(URL_TEST_CONNECTION); + // await connection.connect(); + // const data = buildQuery("SELECT 1+1 AS result;"); + // const result1 = await connection.execute(data); + // assertEquals(result1, { affectedRows: 0, lastInsertId: null }); + // const result2 = await connection.execute(data); + // assertEquals(result2, { affectedRows: 0, lastInsertId: null }); + // }); - await t.step("can sendData", async (t) => { - await using connection = new MysqlConnection("mysql://root@0.0.0.0:3306"); + await t.step("can query database", async (t) => { + await using connection = new MysqlConnection(URL_TEST_CONNECTION); await connection.connect(); - const data = buildQuery("SELECT 1+1 AS result;"); - for await (const result1 of connection.sendData(data)) { - assertEquals(result1, { - row: [2], - fields: [ - { - catalog: "def", - decimals: 0, - defaultVal: "", - encoding: 63, - fieldFlag: 129, - fieldLen: 3, - fieldType: 8, - name: "result", - originName: "", - originTable: "", - schema: "", - table: "", - }, - ], + await t.step("can sendData", async () => { + const data = buildQuery("SELECT 1+1 AS result;"); + for await (const result1 of connection.sendData(data)) { + assertEquals(result1, { + row: [2], + fields: [ + { + catalog: "def", + decimals: 0, + defaultVal: "", + encoding: 63, + fieldFlag: 129, + fieldLen: 3, + fieldType: 8, + name: "result", + originName: "", + originTable: "", + schema: "", + table: "", + }, + ], + }); + } + }); + + await t.step("can drop and create table", async () => { + const dropTableSql = buildQuery("DROP TABLE IF EXISTS test;"); + const dropTableReturned = connection.sendData(dropTableSql); + assertEquals(await dropTableReturned.next(), { + done: true, + value: { affectedRows: 0, lastInsertId: 0 }, + }); + const createTableSql = buildQuery( + "CREATE TABLE IF NOT EXISTS test (id INT);", + ); + const createTableReturned = connection.sendData(createTableSql); + assertEquals(await createTableReturned.next(), { + done: true, + value: { affectedRows: 0, lastInsertId: 0 }, + }); + const result = await Array.fromAsync(createTableReturned); + assertEquals(result, []); + }); + + await t.step("can insert to table", async () => { + const data = buildQuery("INSERT INTO test (id) VALUES (1),(2),(3);"); + const returned = connection.sendData(data); + assertEquals(await returned.next(), { + done: true, + value: { affectedRows: 3, lastInsertId: 0 }, + }); + const result = await Array.fromAsync(returned); + assertEquals(result, []); + }); + + await t.step("can select from table using sendData", async () => { + const data = buildQuery("SELECT * FROM test;"); + const returned = connection.sendData(data); + const result = await Array.fromAsync(returned); + assertEquals(result, [ + { + fields: [ + { + catalog: "def", + decimals: 0, + defaultVal: "", + encoding: 63, + fieldFlag: 0, + fieldLen: 11, + fieldType: 3, + name: "id", + originName: "id", + originTable: "test", + schema: "testdb", + table: "test", + }, + ], + row: [ + 1, + ], + }, + { + fields: [ + { + catalog: "def", + decimals: 0, + defaultVal: "", + encoding: 63, + fieldFlag: 0, + fieldLen: 11, + fieldType: 3, + name: "id", + originName: "id", + originTable: "test", + schema: "testdb", + table: "test", + }, + ], + row: [ + 2, + ], + }, + { + fields: [ + { + catalog: "def", + decimals: 0, + defaultVal: "", + encoding: 63, + fieldFlag: 0, + fieldLen: 11, + fieldType: 3, + name: "id", + originName: "id", + originTable: "test", + schema: "testdb", + table: "test", + }, + ], + row: [ + 3, + ], + }, + ]); + }); + + await t.step("can insert to table using executeRaw", async () => { + const data = buildQuery("INSERT INTO test (id) VALUES (4);"); + const result = await connection.executeRaw(data); + assertEquals(result, 1); + }); + + await t.step("can select from table using executeRaw", async () => { + const data = buildQuery("SELECT * FROM test;"); + const result = await connection.executeRaw(data); + assertEquals(result, undefined); + }); + + await t.step("can insert to table using queryManyObjectRaw", async () => { + const data = buildQuery("INSERT INTO test (id) VALUES (5);"); + const result = await Array.fromAsync(connection.queryManyObjectRaw(data)); + assertEquals(result, []); + }); + + await t.step("can select from table using queryManyObjectRaw", async () => { + const data = buildQuery("SELECT * FROM test;"); + const result = await Array.fromAsync(connection.queryManyObjectRaw(data)); + assertEquals(result, [ + { id: 1 }, + { id: 2 }, + { id: 3 }, + { id: 4 }, + { id: 5 }, + ]); + }); + + await t.step("can insert to table using queryManyArrayRaw", async () => { + const data = buildQuery("INSERT INTO test (id) VALUES (6);"); + const result = await Array.fromAsync(connection.queryManyArrayRaw(data)); + assertEquals(result, []); + }); + + await t.step("can select from table using queryManyArrayRaw", async () => { + const data = buildQuery("SELECT * FROM test;"); + const result = await Array.fromAsync(connection.queryManyArrayRaw(data)); + assertEquals(result, [ + [1], + [2], + [3], + [4], + [5], + [6], + ]); + }); + + await t.step("can drop table", async () => { + const data = buildQuery("DROP TABLE IF EXISTS test;"); + const returned = connection.sendData(data); + assertEquals(await returned.next(), { + done: true, + value: { affectedRows: 0, lastInsertId: 0 }, }); - } + const result = await Array.fromAsync(returned); + assertEquals(result, []); + }); }); await emptyDir(DIR_TMP_TEST); diff --git a/lib/connection.ts b/lib/connection.ts index 1d375d6..8c361e9 100644 --- a/lib/connection.ts +++ b/lib/connection.ts @@ -1,4 +1,3 @@ -import { type ClientConfig, TLSMode } from "./client.ts"; import { MysqlConnectionError, MysqlError, @@ -7,7 +6,6 @@ import { MysqlResponseTimeoutError, } from "./utils/errors.ts"; import { buildAuth } from "./packets/builders/auth.ts"; -import { buildQuery } from "./packets/builders/query.ts"; import { PacketReader, PacketWriter } from "./packets/packet.ts"; import { parseError } from "./packets/parsers/err.ts"; import { @@ -16,17 +14,31 @@ import { parseHandshake, } from "./packets/parsers/handshake.ts"; import { + ConvertTypeOptions, type FieldInfo, + getRowObject, + type MysqlParameterType, parseField, - parseRowObject, + parseRowArray, } from "./packets/parsers/result.ts"; import { ComQueryResponsePacket } from "./constant/packet.ts"; -import { AuthPluginName, AuthPlugins } from "./auth_plugins/mod.ts"; +import { AuthPlugins } from "./auth_plugins/mod.ts"; import { parseAuthSwitch } from "./packets/parsers/authswitch.ts"; import auth from "./utils/hash.ts"; import { ServerCapabilities } from "./constant/capabilities.ts"; import { buildSSLRequest } from "./packets/builders/tls.ts"; import { logger } from "./utils/logger.ts"; +import type { + ArrayRow, + Row, + SqlxConnectable, + SqlxConnectionOptions, +} from "@halvardm/sqlx"; +import { VERSION } from "./utils/meta.ts"; +import { resolve } from "@std/path"; +import { toCamelCase } from "@std/text"; +import { AuthPluginName } from "./auth_plugins/mod.ts"; +import type { MysqlQueryOptions } from "./client.ts"; /** * Connection state @@ -38,19 +50,91 @@ export enum ConnectionState { CLOSED, } +export type ConnectionSendDataNext = { + row: ArrayRow; + fields: FieldInfo[]; +}; +export type ConnectionSendDataResult = { + affectedRows: number | undefined; + lastInsertId: number | undefined; +}; + /** - * Result for execute sql + * Tls mode for mysql connection + * + * @see {@link https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode} */ -export type ExecuteResult = { - affectedRows?: number; - lastInsertId?: number; - fields?: FieldInfo[]; - rows?: any[]; - iterator?: any; -}; +export const TlsMode = { + Preferred: "PREFERRED", + Disabled: "DISABLED", + Required: "REQUIRED", + VerifyCa: "VERIFY_CA", + VerifyIdentity: "VERIFY_IDENTITY", +} as const; +export type TlsMode = typeof TlsMode[keyof typeof TlsMode]; + +export interface TlsOptions extends Deno.ConnectTlsOptions { + mode: TlsMode; +} + +/** + * Aditional connection parameters + * + * @see {@link https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html#connecting-using-uri} + */ +export interface ConnectionParameters { + socket?: string; + sslMode?: TlsMode; + sslCa?: string[]; + sslCapath?: string[]; + sslCert?: string; + sslCipher?: string; + sslCrl?: string; + sslCrlpath?: string; + sslKey?: string; + tlsVersion?: string; + tlsVersions?: string; + tlsCiphersuites?: string; + authMethod?: string; + getServerPublicKey?: boolean; + serverPublicKeyPath?: string; + ssh?: string; + uri?: string; + sshPassword?: string; + sshConfigFile?: string; + sshIdentityFile?: string; + sshIdentityPass?: string; + connectTimeout?: number; + compression?: string; + compressionAlgorithms?: string; + compressionLevel?: string; + connectionAttributes?: string; +} + +export interface ConnectionConfig { + protocol: string; + username: string; + password?: string; + hostname: string; + port: number; + socket?: string; + schema?: string; + /** + * Tls options + */ + tls?: Partial; + /** + * Aditional connection parameters + */ + parameters: ConnectionParameters; +} + +export interface MysqlConnectionOptions extends SqlxConnectionOptions { +} /** Connection for mysql */ -export class Connection { +export class MysqlConnection + implements SqlxConnectable { state: ConnectionState = ConnectionState.CONNECTING; capabilities: number = 0; serverVersion: string = ""; @@ -58,6 +142,11 @@ export class Connection { protected _conn: Deno.Conn | null = null; private _timedOut = false; + readonly connectionUrl: string; + readonly connectionOptions: MysqlConnectionOptions; + readonly config: ConnectionConfig; + readonly sqlxVersion: string = VERSION; + get conn(): Deno.Conn { if (!this._conn) { throw new MysqlConnectionError("Not connected"); @@ -76,43 +165,48 @@ export class Connection { this._conn = conn; } - get remoteAddr(): string { - return this.config.socketPath - ? `unix:${this.config.socketPath}` - : `${this.config.hostname}:${this.config.port}`; + constructor( + connectionUrl: string | URL, + connectionOptions: MysqlConnectionOptions = {}, + ) { + this.connectionUrl = connectionUrl.toString().split("?")[0]; + this.connectionOptions = connectionOptions; + this.config = this.#parseConnectionConfig( + connectionUrl, + connectionOptions, + ); } - - get isMariaDB(): boolean { - return this.serverVersion.includes("MariaDB"); + get connected(): boolean { + return this.state === ConnectionState.CONNECTED; } - constructor(readonly config: ClientConfig) {} - - private async _connect() { + async connect(): Promise { // TODO: implement connect timeout if ( this.config.tls?.mode && - this.config.tls.mode !== TLSMode.DISABLED && - this.config.tls.mode !== TLSMode.VERIFY_IDENTITY + this.config.tls?.mode !== TlsMode.Disabled && + this.config.tls?.mode !== TlsMode.VerifyIdentity ) { - throw new MysqlError("unsupported tls mode"); + throw new Error("unsupported tls mode"); } - const { hostname, port = 3306, socketPath, username = "", password } = - this.config; - logger().info(`connecting ${this.remoteAddr}`); - this.conn = !socketPath - ? await Deno.connect({ - transport: "tcp", - hostname, - port, - }) - : await Deno.connect({ + + logger().info(`connecting ${this.connectionUrl}`); + + if (this.config.socket) { + this.conn = await Deno.connect({ transport: "unix", - path: socketPath, - } as any); + path: this.config.socket, + }); + } else { + this.conn = await Deno.connect({ + transport: "tcp", + hostname: this.config.hostname, + port: this.config.port, + }); + } try { - let receive = await this.nextPacket(); + let receive = await this.#nextPacket(); const handshakePacket = parseHandshake(receive.body); let handshakeSequenceNumber = receive.header.no; @@ -120,20 +214,20 @@ export class Connection { // Deno.startTls() only supports VERIFY_IDENTITY now. let isSSL = false; if ( - this.config.tls?.mode === TLSMode.VERIFY_IDENTITY + this.config.tls?.mode === TlsMode.VerifyIdentity ) { if ( (handshakePacket.serverCapabilities & ServerCapabilities.CLIENT_SSL) === 0 ) { - throw new MysqlError("Server does not support TLS"); + throw new Error("Server does not support TLS"); } if ( (handshakePacket.serverCapabilities & ServerCapabilities.CLIENT_SSL) !== 0 ) { const tlsData = buildSSLRequest(handshakePacket, { - db: this.config.db, + db: this.config.schema, }); await PacketWriter.write( this.conn, @@ -141,7 +235,7 @@ export class Connection { ++handshakeSequenceNumber, ); this.conn = await Deno.startTls(this.conn, { - hostname, + hostname: this.config.hostname, caCerts: this.config.tls?.caCerts, }); } @@ -149,19 +243,19 @@ export class Connection { } const data = await buildAuth(handshakePacket, { - username, - password, - db: this.config.db, + username: this.config.username, + password: this.config.password, + db: this.config.schema, ssl: isSSL, }); - await PacketWriter.write(this.conn, data, ++handshakeSequenceNumber); + await PacketWriter.write(this._conn!, data, ++handshakeSequenceNumber); this.state = ConnectionState.CONNECTING; this.serverVersion = handshakePacket.serverVersion; this.capabilities = handshakePacket.serverCapabilities; - receive = await this.nextPacket(); + receive = await this.#nextPacket(); const authResult = parseAuth(receive); let authPlugin: AuthPluginName | undefined = undefined; @@ -183,22 +277,26 @@ export class Connection { } let authData; - if (password) { + if (this.config.password) { authData = await auth( authSwitch.authPluginName, - password, + this.config.password, authSwitch.authPluginData, ); } else { authData = Uint8Array.from([]); } - await PacketWriter.write(this.conn, authData, receive.header.no + 1); + await PacketWriter.write( + this.conn, + authData, + receive.header.no + 1, + ); - receive = await this.nextPacket(); + receive = await this.#nextPacket(); const authSwitch2 = parseAuthSwitch(receive.body); if (authSwitch2.authPluginName !== "") { - throw new MysqlError( + throw new Error( "Do not allow to change the auth plugin more than once!", ); } @@ -221,10 +319,10 @@ export class Connection { plugin.data, sequenceNumber, ); - receive = await this.nextPacket(); + receive = await this.#nextPacket(); } if (plugin.quickRead) { - await this.nextPacket(); + await this.#nextPacket(); } await plugin.next(receive); @@ -232,7 +330,7 @@ export class Connection { break; } default: - throw new MysqlError("Unsupported auth plugin"); + throw new Error("Unsupported auth plugin"); } } @@ -241,15 +339,11 @@ export class Connection { const error = parseError(receive.body, this); logger().error(`connect error(${error.code}): ${error.message}`); this.close(); - throw new MysqlError(error.message); + throw new Error(error.message); } else { - logger().info(`connected to ${this.remoteAddr}`); + logger().info(`connected to ${this.connectionUrl}`); this.state = ConnectionState.CONNECTED; } - - if (this.config.charset) { - await this.execute(`SET NAMES ${this.config.charset}`); - } } catch (error) { // Call close() to avoid leaking socket. this.close(); @@ -257,25 +351,143 @@ export class Connection { } } - /** Connect to database */ - async connect(): Promise { - await this._connect(); + close(): Promise { + if (this.state != ConnectionState.CLOSED) { + logger().info("close connection"); + this._conn?.close(); + this.state = ConnectionState.CLOSED; + } + return Promise.resolve(); + } + + /** + * Parses the connection url and options into a connection config + */ + #parseConnectionConfig( + connectionUrl: string | URL, + connectionOptions: MysqlConnectionOptions, + ): ConnectionConfig { + function parseParameters(url: URL): ConnectionParameters { + const parameters: ConnectionParameters = {}; + for (const [key, value] of url.searchParams) { + const pKey = toCamelCase(key); + if (pKey === "sslCa") { + if (!parameters.sslCa) { + parameters.sslCa = []; + } + parameters.sslCa.push(value); + } else if (pKey === "sslCapath") { + if (!parameters.sslCapath) { + parameters.sslCapath = []; + } + parameters.sslCapath.push(value); + } else if (pKey === "getServerPublicKey") { + parameters.getServerPublicKey = value === "true"; + } else if (pKey === "connectTimeout") { + parameters.connectTimeout = parseInt(value); + } else { + // deno-lint-ignore no-explicit-any + parameters[pKey as keyof ConnectionParameters] = value as any; + } + } + return parameters; + } + + function parseTlsOptions(config: ConnectionConfig): TlsOptions | undefined { + const baseTlsOptions: TlsOptions = { + port: config.port, + hostname: config.hostname, + mode: TlsMode.Preferred, + }; + + if (connectionOptions.tls) { + return { + ...baseTlsOptions, + ...connectionOptions.tls, + }; + } + + if (config.parameters.sslMode) { + const tlsOptions: TlsOptions = { + ...baseTlsOptions, + mode: config.parameters.sslMode, + }; + + const caCertPaths = new Set(); + + if (config.parameters.sslCa?.length) { + for (const caCert of config.parameters.sslCa) { + caCertPaths.add(resolve(caCert)); + } + } + + if (config.parameters.sslCapath?.length) { + for (const caPath of config.parameters.sslCapath) { + for (const f of Deno.readDirSync(caPath)) { + if (f.isFile && f.name.endsWith(".pem")) { + caCertPaths.add(resolve(caPath, f.name)); + } + } + } + } + + if (caCertPaths.size) { + tlsOptions.caCerts = []; + for (const caCert of caCertPaths) { + const content = Deno.readTextFileSync(caCert); + tlsOptions.caCerts.push(content); + } + } + + if (config.parameters.sslKey) { + tlsOptions.key = Deno.readTextFileSync( + resolve(config.parameters.sslKey), + ); + } + + if (config.parameters.sslCert) { + tlsOptions.cert = Deno.readTextFileSync( + resolve(config.parameters.sslCert), + ); + } + + return tlsOptions; + } + return undefined; + } + + const url = new URL(connectionUrl); + const parameters = parseParameters(url); + const config: ConnectionConfig = { + protocol: url.protocol.slice(0, -1), + username: url.username, + password: url.password || undefined, + hostname: url.hostname, + port: parseInt(url.port || "3306"), + schema: url.pathname.slice(1), + parameters: parameters, + socket: parameters.socket, + }; + + config.tls = parseTlsOptions(config); + + return config; } - private async nextPacket(): Promise { - if (!this.conn) { + async #nextPacket(): Promise { + if (!this._conn) { throw new MysqlConnectionError("Not connected"); } - const timeoutTimer = this.config.timeout + const timeoutTimer = this.config.parameters.connectTimeout ? setTimeout( - this._timeoutCallback, - this.config.timeout, + this.#timeoutCallback, + this.config.parameters.connectTimeout, ) : null; let packet: PacketReader | null; try { - packet = await PacketReader.read(this.conn); + packet = await PacketReader.read(this._conn); } catch (error) { if (this._timedOut) { // Connection has been closed by timeoutCallback. @@ -296,62 +508,28 @@ export class Connection { if (packet.type === ComQueryResponsePacket.ERR_Packet) { packet.body.skip(1); const error = parseError(packet.body, this); - throw new MysqlError(error.message); + throw new Error(error.message); } - return packet!; + return packet; } - private _timeoutCallback = () => { + #timeoutCallback = () => { logger().info("connection read timed out"); this._timedOut = true; this.close(); }; - /** Close database connection */ - close(): void { - if (this.state != ConnectionState.CLOSED) { - logger().info("close connection"); - this.conn?.close(); - this.state = ConnectionState.CLOSED; - } - } - - /** - * excute query sql - * @param sql query sql string - * @param params query params - */ - async query(sql: string, params?: any[]): Promise { - const result = await this.execute(sql, params); - if (result && result.rows) { - return result.rows; - } else { - return result; - } - } - - /** - * execute sql - * @param sql sql string - * @param params query params - * @param iterator whether to return an ExecuteIteratorResult or ExecuteResult - */ - async execute( - sql: string, - params?: any[], - iterator = false, - ): Promise { - if (this.state != ConnectionState.CONNECTED) { - if (this.state == ConnectionState.CLOSED) { - throw new MysqlConnectionError("Connection is closed"); - } else { - throw new MysqlConnectionError("Must be connected first"); - } - } - const data = buildQuery(sql, params); + async *sendData( + data: Uint8Array, + options?: ConvertTypeOptions, + ): AsyncGenerator< + ConnectionSendDataNext, + ConnectionSendDataResult | undefined + > { try { await PacketWriter.write(this.conn, data, 0); - let receive = await this.nextPacket(); + let receive = await this.#nextPacket(); + logger().debug(`packet type: ${receive.type.toString()}`); if (receive.type === ComQueryResponsePacket.OK_Packet) { receive.body.skip(1); return { @@ -364,67 +542,78 @@ export class Connection { let fieldCount = receive.body.readEncodedLen(); const fields: FieldInfo[] = []; while (fieldCount--) { - const packet = await this.nextPacket(); + const packet = await this.#nextPacket(); if (packet) { const field = parseField(packet.body); fields.push(field); } } - const rows = []; if (!(this.capabilities & ServerCapabilities.CLIENT_DEPRECATE_EOF)) { // EOF(mysql < 5.7 or mariadb < 10.2) - receive = await this.nextPacket(); + receive = await this.#nextPacket(); if (receive.type !== ComQueryResponsePacket.EOF_Packet) { throw new MysqlProtocolError(receive.type.toString()); } } - if (!iterator) { - while (true) { - receive = await this.nextPacket(); - if (receive.type === ComQueryResponsePacket.EOF_Packet) { - break; - } else { - const row = parseRowObject(receive.body, fields); - rows.push(row); - } - } - return { rows, fields }; - } + receive = await this.#nextPacket(); - return { - fields, - iterator: this.buildIterator(fields), - }; + while (receive.type !== ComQueryResponsePacket.EOF_Packet) { + const row = parseRowArray(receive.body, fields, options); + yield { + row, + fields, + }; + receive = await this.#nextPacket(); + } } catch (error) { this.close(); throw error; } } - private buildIterator(fields: FieldInfo[]): any { - const next = async () => { - const receive = await this.nextPacket(); + async executeRaw( + data: Uint8Array, + options?: ConvertTypeOptions, + ): Promise { + const gen = this.sendData(data, options); + let result = await gen.next(); + if (result.done) { + return result.value?.affectedRows; + } - if (receive.type === ComQueryResponsePacket.EOF_Packet) { - return { done: true }; - } + const debugRest = []; + debugRest.push(result); + while (!result.done) { + result = await gen.next(); + debugRest.push(result); + logger().debug(`executeRaw overflow: ${JSON.stringify(debugRest)}`); + } + logger().debug(`executeRaw overflow: ${JSON.stringify(debugRest)}`); + return undefined; + } - const value = parseRowObject(receive.body, fields); + async *queryManyObjectRaw = Row>( + data: Uint8Array, + options?: ConvertTypeOptions, + ): AsyncIterableIterator { + for await (const res of this.sendData(data, options)) { + yield getRowObject(res.fields, res.row) as T; + } + } - return { - done: false, - value, - }; - }; + async *queryManyArrayRaw = ArrayRow>( + data: Uint8Array, + options?: ConvertTypeOptions, + ): AsyncIterableIterator { + for await (const res of this.sendData(data, options)) { + const row = res.row as T; + yield row as T; + } + } - return { - [Symbol.asyncIterator]: () => { - return { - next, - }; - }, - }; + async [Symbol.asyncDispose](): Promise { + await this.close(); } } diff --git a/lib/connection2.ts b/lib/connection2.ts deleted file mode 100644 index cc041e9..0000000 --- a/lib/connection2.ts +++ /dev/null @@ -1,590 +0,0 @@ -import { - MysqlConnectionError, - MysqlProtocolError, - MysqlReadError, - MysqlResponseTimeoutError, -} from "./utils/errors.ts"; -import { buildAuth } from "./packets/builders/auth.ts"; -import { PacketReader, PacketWriter } from "./packets/packet.ts"; -import { parseError } from "./packets/parsers/err.ts"; -import { - AuthResult, - parseAuth, - parseHandshake, -} from "./packets/parsers/handshake.ts"; -import { - type FieldInfo, - parseField, - parseRowArray, -} from "./packets/parsers/result.ts"; -import { ComQueryResponsePacket } from "./constant/packet.ts"; -import { AuthPlugins } from "./auth_plugins/mod.ts"; -import { parseAuthSwitch } from "./packets/parsers/authswitch.ts"; -import auth from "./utils/hash.ts"; -import { ServerCapabilities } from "./constant/capabilities.ts"; -import { buildSSLRequest } from "./packets/builders/tls.ts"; -import { logger } from "./utils/logger.ts"; -import type { - ArrayRow, - SqlxConnectable, - SqlxConnectionOptions, - SqlxParameterType, -} from "@halvardm/sqlx"; -import { VERSION } from "./utils/meta.ts"; -import { resolve } from "@std/path"; -import { toCamelCase } from "@std/text"; -import { AuthPluginName } from "./auth_plugins/mod.ts"; -export type MysqlParameterType = SqlxParameterType; - -/** - * Connection state - */ -export enum ConnectionState { - CONNECTING, - CONNECTED, - CLOSING, - CLOSED, -} - -export type ConnectionSendDataResult = { - affectedRows: number; - lastInsertId: number | null; -} | undefined; - -export type ConnectionSendDataNext = { - row: ArrayRow; - fields: FieldInfo[]; -}; - -export interface ConnectionOptions extends SqlxConnectionOptions { -} - -/** - * Tls mode for mysql connection - * - * @see {@link https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode} - */ -export const TlsMode = { - Preferred: "PREFERRED", - Disabled: "DISABLED", - Required: "REQUIRED", - VerifyCa: "VERIFY_CA", - VerifyIdentity: "VERIFY_IDENTITY", -} as const; -export type TlsMode = typeof TlsMode[keyof typeof TlsMode]; - -export interface TlsOptions extends Deno.ConnectTlsOptions { - mode: TlsMode; -} - -/** - * Aditional connection parameters - * - * @see {@link https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html#connecting-using-uri} - */ -export interface ConnectionParameters { - socket?: string; - sslMode?: TlsMode; - sslCa?: string[]; - sslCapath?: string[]; - sslCert?: string; - sslCipher?: string; - sslCrl?: string; - sslCrlpath?: string; - sslKey?: string; - tlsVersion?: string; - tlsVersions?: string; - tlsCiphersuites?: string; - authMethod?: string; - getServerPublicKey?: boolean; - serverPublicKeyPath?: string; - ssh?: string; - uri?: string; - sshPassword?: string; - sshConfigFile?: string; - sshIdentityFile?: string; - sshIdentityPass?: string; - connectTimeout?: number; - compression?: string; - compressionAlgorithms?: string; - compressionLevel?: string; - connectionAttributes?: string; -} - -export interface ConnectionConfig { - protocol: string; - username: string; - password?: string; - hostname: string; - port: number; - socket?: string; - schema?: string; - /** - * Tls options - */ - tls?: Partial; - /** - * Aditional connection parameters - */ - parameters: ConnectionParameters; -} - -/** Connection for mysql */ -export class MysqlConnection implements SqlxConnectable { - state: ConnectionState = ConnectionState.CONNECTING; - capabilities: number = 0; - serverVersion: string = ""; - - protected _conn: Deno.Conn | null = null; - private _timedOut = false; - - readonly connectionUrl: string; - readonly connectionOptions: ConnectionOptions; - readonly config: ConnectionConfig; - readonly sqlxVersion: string = VERSION; - - get conn(): Deno.Conn { - if (!this._conn) { - throw new MysqlConnectionError("Not connected"); - } - if (this.state != ConnectionState.CONNECTED) { - if (this.state == ConnectionState.CLOSED) { - throw new MysqlConnectionError("Connection is closed"); - } else { - throw new MysqlConnectionError("Must be connected first"); - } - } - return this._conn; - } - - set conn(conn: Deno.Conn | null) { - this._conn = conn; - } - - constructor( - connectionUrl: string | URL, - connectionOptions: ConnectionOptions = {}, - ) { - this.connectionUrl = connectionUrl.toString().split("?")[0]; - this.connectionOptions = connectionOptions; - this.config = this.#parseConnectionConfig( - connectionUrl, - connectionOptions, - ); - } - get connected(): boolean { - return this.state === ConnectionState.CONNECTED; - } - - async connect(): Promise { - // TODO: implement connect timeout - if ( - this.config.tls?.mode && - this.config.tls?.mode !== TlsMode.Disabled && - this.config.tls?.mode !== TlsMode.VerifyIdentity - ) { - throw new Error("unsupported tls mode"); - } - - logger().info(`connecting ${this.connectionUrl}`); - - if (this.config.socket) { - this.conn = await Deno.connect({ - transport: "unix", - path: this.config.socket, - }); - } else { - this.conn = await Deno.connect({ - transport: "tcp", - hostname: this.config.hostname, - port: this.config.port, - }); - } - - try { - let receive = await this.#nextPacket(); - const handshakePacket = parseHandshake(receive.body); - - let handshakeSequenceNumber = receive.header.no; - - // Deno.startTls() only supports VERIFY_IDENTITY now. - let isSSL = false; - if ( - this.config.tls?.mode === TlsMode.VerifyIdentity - ) { - if ( - (handshakePacket.serverCapabilities & - ServerCapabilities.CLIENT_SSL) === 0 - ) { - throw new Error("Server does not support TLS"); - } - if ( - (handshakePacket.serverCapabilities & - ServerCapabilities.CLIENT_SSL) !== 0 - ) { - const tlsData = buildSSLRequest(handshakePacket, { - db: this.config.schema, - }); - await PacketWriter.write( - this.conn, - tlsData, - ++handshakeSequenceNumber, - ); - this.conn = await Deno.startTls(this.conn, { - hostname: this.config.hostname, - caCerts: this.config.tls?.caCerts, - }); - } - isSSL = true; - } - - const data = await buildAuth(handshakePacket, { - username: this.config.username, - password: this.config.password, - db: this.config.schema, - ssl: isSSL, - }); - - await PacketWriter.write(this._conn!, data, ++handshakeSequenceNumber); - - this.state = ConnectionState.CONNECTING; - this.serverVersion = handshakePacket.serverVersion; - this.capabilities = handshakePacket.serverCapabilities; - - receive = await this.#nextPacket(); - - const authResult = parseAuth(receive); - let authPlugin: AuthPluginName | undefined = undefined; - - switch (authResult) { - case AuthResult.AuthMoreRequired: { - authPlugin = handshakePacket.authPluginName as AuthPluginName; - break; - } - case AuthResult.MethodMismatch: { - const authSwitch = parseAuthSwitch(receive.body); - // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is - // sent and we have to keep using the cipher sent in the init packet. - if ( - authSwitch.authPluginData === undefined || - authSwitch.authPluginData.length === 0 - ) { - authSwitch.authPluginData = handshakePacket.seed; - } - - let authData; - if (this.config.password) { - authData = await auth( - authSwitch.authPluginName, - this.config.password, - authSwitch.authPluginData, - ); - } else { - authData = Uint8Array.from([]); - } - - await PacketWriter.write( - this.conn, - authData, - receive.header.no + 1, - ); - - receive = await this.#nextPacket(); - const authSwitch2 = parseAuthSwitch(receive.body); - if (authSwitch2.authPluginName !== "") { - throw new Error( - "Do not allow to change the auth plugin more than once!", - ); - } - } - } - - if (authPlugin) { - switch (authPlugin) { - case AuthPluginName.CachingSha2Password: { - const plugin = new AuthPlugins[authPlugin]( - handshakePacket.seed, - this.config.password!, - ); - - while (!plugin.done) { - if (plugin.data) { - const sequenceNumber = receive.header.no + 1; - await PacketWriter.write( - this.conn, - plugin.data, - sequenceNumber, - ); - receive = await this.#nextPacket(); - } - if (plugin.quickRead) { - await this.#nextPacket(); - } - - await plugin.next(receive); - } - break; - } - default: - throw new Error("Unsupported auth plugin"); - } - } - - const header = receive.body.readUint8(); - if (header === 0xff) { - const error = parseError(receive.body, this as any); - logger().error(`connect error(${error.code}): ${error.message}`); - this.close(); - throw new Error(error.message); - } else { - logger().info(`connected to ${this.connectionUrl}`); - this.state = ConnectionState.CONNECTED; - } - } catch (error) { - // Call close() to avoid leaking socket. - this.close(); - throw error; - } - } - - async close(): Promise { - if (this.state != ConnectionState.CLOSED) { - logger().info("close connection"); - this._conn?.close(); - this.state = ConnectionState.CLOSED; - } - } - - /** - * Parses the connection url and options into a connection config - */ - #parseConnectionConfig( - connectionUrl: string | URL, - connectionOptions: ConnectionOptions, - ): ConnectionConfig { - function parseParameters(url: URL): ConnectionParameters { - const parameters: ConnectionParameters = {}; - for (const [key, value] of url.searchParams) { - const pKey = toCamelCase(key); - if (pKey === "sslCa") { - if (!parameters.sslCa) { - parameters.sslCa = []; - } - parameters.sslCa.push(value); - } else if (pKey === "sslCapath") { - if (!parameters.sslCapath) { - parameters.sslCapath = []; - } - parameters.sslCapath.push(value); - } else if (pKey === "getServerPublicKey") { - parameters.getServerPublicKey = value === "true"; - } else if (pKey === "connectTimeout") { - parameters.connectTimeout = parseInt(value); - } else { - parameters[pKey as keyof ConnectionParameters] = value as any; - } - } - return parameters; - } - - function parseTlsOptions(config: ConnectionConfig): TlsOptions | undefined { - const baseTlsOptions: TlsOptions = { - port: config.port, - hostname: config.hostname, - mode: TlsMode.Preferred, - }; - - if (connectionOptions.tls) { - return { - ...baseTlsOptions, - ...connectionOptions.tls, - }; - } - - if (config.parameters.sslMode) { - const tlsOptions: TlsOptions = { - ...baseTlsOptions, - mode: config.parameters.sslMode, - }; - - const caCertPaths = new Set(); - - if (config.parameters.sslCa?.length) { - for (const caCert of config.parameters.sslCa) { - caCertPaths.add(resolve(caCert)); - } - } - - if (config.parameters.sslCapath?.length) { - for (const caPath of config.parameters.sslCapath) { - for (const f of Deno.readDirSync(caPath)) { - if (f.isFile && f.name.endsWith(".pem")) { - caCertPaths.add(resolve(caPath, f.name)); - } - } - } - } - - if (caCertPaths.size) { - tlsOptions.caCerts = []; - for (const caCert of caCertPaths) { - const content = Deno.readTextFileSync(caCert); - tlsOptions.caCerts.push(content); - } - } - - if (config.parameters.sslKey) { - tlsOptions.key = Deno.readTextFileSync( - resolve(config.parameters.sslKey), - ); - } - - if (config.parameters.sslCert) { - tlsOptions.cert = Deno.readTextFileSync( - resolve(config.parameters.sslCert), - ); - } - - return tlsOptions; - } - return undefined; - } - - const url = new URL(connectionUrl); - const parameters = parseParameters(url); - const config: ConnectionConfig = { - protocol: url.protocol.slice(0, -1), - username: url.username, - password: url.password || undefined, - hostname: url.hostname, - port: parseInt(url.port || "3306"), - schema: url.pathname.slice(1), - parameters: parameters, - socket: parameters.socket, - }; - - config.tls = parseTlsOptions(config); - - return config; - } - - async #nextPacket(): Promise { - if (!this._conn) { - throw new MysqlConnectionError("Not connected"); - } - - const timeoutTimer = this.config.parameters.connectTimeout - ? setTimeout( - this.#timeoutCallback, - this.config.parameters.connectTimeout, - ) - : null; - let packet: PacketReader | null; - try { - packet = await PacketReader.read(this._conn); - } catch (error) { - if (this._timedOut) { - // Connection has been closed by timeoutCallback. - throw new MysqlResponseTimeoutError("Connection read timed out"); - } - timeoutTimer && clearTimeout(timeoutTimer); - this.close(); - throw error; - } - timeoutTimer && clearTimeout(timeoutTimer); - - if (!packet) { - // Connection is half-closed by the remote host. - // Call close() to avoid leaking socket. - this.close(); - throw new MysqlReadError("Connection closed unexpectedly"); - } - if (packet.type === ComQueryResponsePacket.ERR_Packet) { - packet.body.skip(1); - const error = parseError(packet.body, this as any); - throw new Error(error.message); - } - return packet!; - } - - #timeoutCallback = () => { - logger().info("connection read timed out"); - this._timedOut = true; - this.close(); - }; - - async *sendData( - data: Uint8Array, - ): AsyncGenerator { - try { - await PacketWriter.write(this.conn, data, 0); - let receive = await this.#nextPacket(); - if (receive.type === ComQueryResponsePacket.OK_Packet) { - receive.body.skip(1); - return { - affectedRows: receive.body.readEncodedLen(), - lastInsertId: receive.body.readEncodedLen(), - }; - } else if (receive.type !== ComQueryResponsePacket.Result) { - throw new MysqlProtocolError(receive.type.toString()); - } - let fieldCount = receive.body.readEncodedLen(); - const fields: FieldInfo[] = []; - while (fieldCount--) { - const packet = await this.#nextPacket(); - if (packet) { - const field = parseField(packet.body); - fields.push(field); - } - } - - if (!(this.capabilities & ServerCapabilities.CLIENT_DEPRECATE_EOF)) { - // EOF(mysql < 5.7 or mariadb < 10.2) - receive = await this.#nextPacket(); - if (receive.type !== ComQueryResponsePacket.EOF_Packet) { - throw new MysqlProtocolError(receive.type.toString()); - } - } - - receive = await this.#nextPacket(); - - while (receive.type !== ComQueryResponsePacket.EOF_Packet) { - const row = parseRowArray(receive.body, fields); - yield { row, fields }; - receive = await this.#nextPacket(); - } - } catch (error) { - this.close(); - throw error; - } - } - - async execute( - data: Uint8Array, - ): Promise { - try { - await PacketWriter.write(this.conn, data, 0); - const receive = await this.#nextPacket(); - if (receive.type === ComQueryResponsePacket.OK_Packet) { - receive.body.skip(1); - return { - affectedRows: receive.body.readEncodedLen(), - lastInsertId: receive.body.readEncodedLen(), - }; - } else if (receive.type !== ComQueryResponsePacket.Result) { - throw new MysqlProtocolError(receive.type.toString()); - } - return { - affectedRows: 0, - lastInsertId: null, - }; - } catch (error) { - this.close(); - throw error; - } - } - - async [Symbol.asyncDispose](): Promise { - await this.close(); - } -} diff --git a/lib/packets/builders/query.ts b/lib/packets/builders/query.ts index ed592bd..0ba9d75 100644 --- a/lib/packets/builders/query.ts +++ b/lib/packets/builders/query.ts @@ -1,9 +1,13 @@ import { replaceParams } from "../../utils/query.ts"; import { BufferWriter } from "../../utils/buffer.ts"; import { encode } from "../../utils/encoding.ts"; +import type { MysqlParameterType } from "../parsers/result.ts"; /** @ignore */ -export function buildQuery(sql: string, params: any[] = []): Uint8Array { +export function buildQuery( + sql: string, + params: MysqlParameterType[] = [], +): Uint8Array { const data = encode(replaceParams(sql, params)); const writer = new BufferWriter(new Uint8Array(data.length + 1)); writer.write(0x03); diff --git a/lib/packets/parsers/err.ts b/lib/packets/parsers/err.ts index 448405c..dac14ef 100644 --- a/lib/packets/parsers/err.ts +++ b/lib/packets/parsers/err.ts @@ -1,5 +1,5 @@ import type { BufferReader } from "../../utils/buffer.ts"; -import type { Connection } from "../../connection.ts"; +import type { MysqlConnection } from "../../connection.ts"; import { ServerCapabilities } from "../../constant/capabilities.ts"; /** @ignore */ @@ -13,7 +13,7 @@ export interface ErrorPacket { /** @ignore */ export function parseError( reader: BufferReader, - conn: Connection, + conn: MysqlConnection, ): ErrorPacket { const code = reader.readUint16(); const packet: ErrorPacket = { diff --git a/lib/packets/parsers/result.ts b/lib/packets/parsers/result.ts index 08ce798..eb9c525 100644 --- a/lib/packets/parsers/result.ts +++ b/lib/packets/parsers/result.ts @@ -1,6 +1,11 @@ import type { BufferReader } from "../../utils/buffer.ts"; import { MysqlDataType } from "../../constant/mysql_types.ts"; -import type { ArrayRow, Row, SqlxParameterType } from "@halvardm/sqlx"; +import type { + ArrayRow, + Row, + SqlxParameterType, + SqlxQueryOptions, +} from "@halvardm/sqlx"; export type MysqlParameterType = SqlxParameterType< string | number | bigint | Date | null @@ -24,6 +29,8 @@ export interface FieldInfo { defaultVal: string; } +export type ConvertTypeOptions = Pick; + /** * Parses the field */ @@ -64,11 +71,12 @@ export function parseField(reader: BufferReader): FieldInfo { export function parseRowArray( reader: BufferReader, fields: FieldInfo[], + options?: ConvertTypeOptions, ): ArrayRow { const row: MysqlParameterType[] = []; for (const field of fields) { const val = reader.readLenCodeString(); - const parsedVal = val === null ? null : convertType(field, val); + const parsedVal = val === null ? null : convertType(field, val, options); row.push(parsedVal); } return row; @@ -100,7 +108,15 @@ export function getRowObject( /** * Converts the value to the correct type */ -function convertType(field: FieldInfo, val: string): MysqlParameterType { +function convertType( + field: FieldInfo, + val: string, + options?: ConvertTypeOptions, +): MysqlParameterType { + if (options?.transformType) { + // deno-lint-ignore no-explicit-any + return options.transformType(val) as any; + } const { fieldType } = field; switch (fieldType) { case MysqlDataType.Decimal: diff --git a/lib/utils/logger.ts b/lib/utils/logger.ts index f70b71d..9cd7c34 100644 --- a/lib/utils/logger.ts +++ b/lib/utils/logger.ts @@ -1,4 +1,4 @@ -import { getLogger } from "@std/log"; +import { ConsoleHandler, getLogger, setup } from "@std/log"; import { MODULE_NAME } from "./meta.ts"; /** @@ -10,3 +10,20 @@ import { MODULE_NAME } from "./meta.ts"; export function logger() { return getLogger(MODULE_NAME); } + +setup({ + handlers: { + console: new ConsoleHandler("DEBUG"), + }, + loggers: { + // configure default logger available via short-hand methods above + default: { + level: "INFO", + handlers: ["console"], + }, + [MODULE_NAME]: { + level: "INFO", + handlers: ["console"], + }, + }, +}); diff --git a/lib/utils/query.ts b/lib/utils/query.ts index 896a565..4179d1c 100644 --- a/lib/utils/query.ts +++ b/lib/utils/query.ts @@ -1,9 +1,14 @@ +import type { MysqlParameterType } from "../packets/parsers/result.ts"; + /** * Replaces parameters in a SQL query with the given values. * * Taken from https://github.com/manyuanrong/sql-builder/blob/master/util.ts */ -export function replaceParams(sql: string, params: any | any[]): string { +export function replaceParams( + sql: string, + params: MysqlParameterType[], +): string { if (!params) return sql; let paramIndex = 0; sql = sql.replace( @@ -46,9 +51,10 @@ export function replaceParams(sql: string, params: any | any[]): string { // deno-lint-ignore no-fallthrough case "object": if (val instanceof Date) return `"${formatDate(val)}"`; - if (val instanceof Array) { + if ((val as unknown) instanceof Array) { return `(${ - val.map((item) => replaceParams("?", [item])).join(",") + (val as Array).map((item) => replaceParams("?", [item])) + .join(",") })`; } case "string": @@ -58,7 +64,7 @@ export function replaceParams(sql: string, params: any | any[]): string { case "number": case "boolean": default: - return val; + return val.toString(); } }, ); @@ -69,6 +75,7 @@ export function replaceParams(sql: string, params: any | any[]): string { * Formats date to a 'YYYY-MM-DD HH:MM:SS.SSS' string. */ function formatDate(date: Date) { + date.toISOString(); const year = date.getFullYear(); const month = (date.getMonth() + 1).toString().padStart(2, "0"); const days = date