Skip to content

Commit

Permalink
Add extension functions for the ResultSet (#772)
Browse files Browse the repository at this point in the history
* Add extension functions for the ResultSet

* Add extension functions for Connection and DatabaseConfiguration

* Refactor database configuration and dataframe methods.

Renamed `DatabaseConfiguration` to `DbConnectionConfig` for clarity. Replaced `.toDF` with `.readDataFrame` methods to improve method naming consistency. These changes enhance code readability and maintainability.

* Refactor SQL reading and schema functions in readJdbc.kt

Simplify the logic to use single-expression functions for readability. Ensure consistent formatting and make error messages more explicit. This change also corrects minor indentation issues in SQL query strings within tests.

* Rename DatabaseConfiguration to DbConnectionConfig for consistency

This commit updates various imports and references from DatabaseConfiguration to DbConnectionConfig across different files. This change ensures consistency in the naming convention used throughout the codebase and documentation, improving clarity and maintenance.
  • Loading branch information
zaleslaw authored Jul 29, 2024
1 parent 73ba813 commit 958f1df
Show file tree
Hide file tree
Showing 5 changed files with 385 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public data class TableMetadata(val name: String, val schemaName: String?, val c
* @property [user] the username used for authentication (optional, default is empty string).
* @property [password] the password used for authentication (optional, default is empty string).
*/
public data class DatabaseConfiguration(val url: String, val user: String = "", val password: String = "")
public data class DbConnectionConfig(val url: String, val user: String = "", val password: String = "")

/**
* Reads data from an SQL table and converts it into a DataFrame.
Expand All @@ -110,7 +110,7 @@ public data class DatabaseConfiguration(val url: String, val user: String = "",
* @return the DataFrame containing the data from the SQL table.
*/
public fun DataFrame.Companion.readSqlTable(
dbConfig: DatabaseConfiguration,
dbConfig: DbConnectionConfig,
tableName: String,
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
Expand Down Expand Up @@ -169,7 +169,7 @@ public fun DataFrame.Companion.readSqlTable(
* @return the DataFrame containing the result of the SQL query.
*/
public fun DataFrame.Companion.readSqlQuery(
dbConfig: DatabaseConfiguration,
dbConfig: DbConnectionConfig,
sqlQuery: String,
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
Expand Down Expand Up @@ -218,6 +218,89 @@ public fun DataFrame.Companion.readSqlQuery(
}
}

/**
 * Loads a DataFrame from the database described by this [DbConnectionConfig],
 * interpreting the argument either as an SQL query or as the name of an SQL table.
 *
 * The argument is checked for being an SQL query first and for being a table name second;
 * the read is then delegated to `DataFrame.readSqlQuery` or `DataFrame.readSqlTable` respectively.
 *
 * @param [sqlQueryOrTableName] either the SQL query to run or the name of an existing SQL table.
 * A query should start from SELECT, consist of a single statement that only reads data,
 * and should not contain the `;` symbol.
 * @param [limit] the maximum number of rows to retrieve.
 * @param [inferNullability] indicates how the column nullability should be inferred.
 * @return the DataFrame holding the rows produced by the query or read from the table.
 * @throws IllegalArgumentException if [sqlQueryOrTableName] is recognized neither as an SQL query
 * nor as a table name.
 */
public fun DbConnectionConfig.readDataFrame(
    sqlQueryOrTableName: String,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // The query check runs first, exactly as in the table-name fallback order below.
    if (isSqlQuery(sqlQueryOrTableName)) {
        return DataFrame.readSqlQuery(this, sqlQueryOrTableName, limit, inferNullability)
    }

    if (isSqlTableName(sqlQueryOrTableName)) {
        return DataFrame.readSqlTable(this, sqlQueryOrTableName, limit, inferNullability)
    }

    throw IllegalArgumentException(
        "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
    )
}

/**
 * Tells whether the given text should be treated as an SQL query.
 *
 * The text counts as a query when it contains the standalone word SELECT in any letter case;
 * the word may appear anywhere in the text, not only at its beginning.
 */
private fun isSqlQuery(sqlQueryOrTableName: String): Boolean =
    Regex("(?i)\\b(SELECT)\\b").containsMatchIn(sqlQueryOrTableName.trim())

/**
 * Tells whether the given text (ignoring surrounding whitespace) looks like an SQL table name.
 *
 * Accepted form: one to three dot-separated identifiers (e.g. `table`, `schema.table`,
 * `catalog.schema.table`), where each identifier starts with an ASCII letter or underscore
 * and continues with ASCII letters, digits or underscores.
 */
private fun isSqlTableName(sqlQueryOrTableName: String): Boolean =
    sqlQueryOrTableName.trim()
        .matches(Regex("^[a-zA-Z_][a-zA-Z0-9_]*(\\.[a-zA-Z_][a-zA-Z0-9_]*){0,2}$"))

/**
 * Loads a DataFrame through this [Connection], interpreting the argument
 * either as an SQL query or as the name of an SQL table.
 *
 * The argument is checked for being an SQL query first and for being a table name second;
 * the read is then delegated to `DataFrame.readSqlQuery` or `DataFrame.readSqlTable` respectively.
 *
 * @param [sqlQueryOrTableName] either the SQL query to run or the name of an existing SQL table.
 * A query should start from SELECT, consist of a single statement that only reads data,
 * and should not contain the `;` symbol.
 * @param [limit] the maximum number of rows to retrieve.
 * @param [inferNullability] indicates how the column nullability should be inferred.
 * @return the DataFrame holding the rows produced by the query or read from the table.
 * @throws IllegalArgumentException if [sqlQueryOrTableName] is recognized neither as an SQL query
 * nor as a table name.
 */
public fun Connection.readDataFrame(
    sqlQueryOrTableName: String,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // The query check runs first; the table-name check is only a fallback.
    if (isSqlQuery(sqlQueryOrTableName)) {
        return DataFrame.readSqlQuery(this, sqlQueryOrTableName, limit, inferNullability)
    }

    if (isSqlTableName(sqlQueryOrTableName)) {
        return DataFrame.readSqlTable(this, sqlQueryOrTableName, limit, inferNullability)
    }

    throw IllegalArgumentException(
        "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
    )
}

/** An SQL query is accepted only if it starts with SELECT. */
private fun isValid(sqlQuery: String): Boolean {
val normalizedSqlQuery = sqlQuery.trim().uppercase()
Expand Down Expand Up @@ -256,6 +339,30 @@ public fun DataFrame.Companion.readResultSet(
return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit, inferNullability)
}

/**
 * Converts the rows of this [ResultSet][java.sql.ResultSet] into a DataFrame.
 *
 * A [ResultSet][java.sql.ResultSet] keeps a cursor that points at its current row.
 * Unless configured otherwise, it is not updatable and its cursor moves forward only,
 * so it can be traversed just once, from the first row to the last one.
 *
 * See the official Java documentation on [ResultSet][java.sql.ResultSet] for more details.
 *
 * NOTE: Reading from the [ResultSet][java.sql.ResultSet] could potentially change its state.
 *
 * @param [dbType] the type of database that this [ResultSet] belongs to.
 * @param [limit] the maximum number of rows to read from the [ResultSet][java.sql.ResultSet].
 * @param [inferNullability] indicates how the column nullability should be inferred.
 * @return the DataFrame built from the [ResultSet][java.sql.ResultSet] data.
 *
 * [java.sql.ResultSet]: https://docs.oracle.com/javase/8/docs/api/java/sql/ResultSet.html
 */
public fun ResultSet.readDataFrame(
    dbType: DbType,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // Pure delegation to the DataFrame-companion reader, with this result set as its source.
    return DataFrame.readResultSet(this, dbType, limit, inferNullability)
}

/**
* Reads the data from a [ResultSet][java.sql.ResultSet] and converts it into a DataFrame.
*
Expand Down Expand Up @@ -288,6 +395,31 @@ public fun DataFrame.Companion.readResultSet(
return readResultSet(resultSet, dbType, limit, inferNullability)
}

/**
 * Converts the rows of this [ResultSet][java.sql.ResultSet] into a DataFrame.
 *
 * A [ResultSet][java.sql.ResultSet] keeps a cursor that points at its current row.
 * Unless configured otherwise, it is not updatable and its cursor moves forward only,
 * so it can be traversed just once, from the first row to the last one.
 *
 * See the official Java documentation on [ResultSet][java.sql.ResultSet] for more details.
 *
 * NOTE: Reading from the [ResultSet][java.sql.ResultSet] could potentially change its state.
 *
 * @param [connection] the connection to the database that this [ResultSet] belongs to
 * (it's required to extract the database type).
 * @param [limit] the maximum number of rows to read from the [ResultSet][java.sql.ResultSet].
 * @param [inferNullability] indicates how the column nullability should be inferred.
 * @return the DataFrame built from the [ResultSet][java.sql.ResultSet] data.
 *
 * [java.sql.ResultSet]: https://docs.oracle.com/javase/8/docs/api/java/sql/ResultSet.html
 */
public fun ResultSet.readDataFrame(
    connection: Connection,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // Pure delegation to the DataFrame-companion reader, with this result set as its source.
    return DataFrame.readResultSet(this, connection, limit, inferNullability)
}

/**
* Reads all non-system tables from a database and returns them
* as a map of SQL tables and corresponding dataframes using the provided database configuration and limit.
Expand All @@ -299,7 +431,7 @@ public fun DataFrame.Companion.readResultSet(
* @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database.
*/
public fun DataFrame.Companion.readAllSqlTables(
dbConfig: DatabaseConfiguration,
dbConfig: DbConnectionConfig,
catalogue: String? = null,
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
Expand Down Expand Up @@ -366,10 +498,7 @@ public fun DataFrame.Companion.readAllSqlTables(
* @param [tableName] the name of the SQL table for which to retrieve the schema.
* @return the [DataFrameSchema] object representing the schema of the SQL table
*/
public fun DataFrame.Companion.getSchemaForSqlTable(
dbConfig: DatabaseConfiguration,
tableName: String,
): DataFrameSchema {
public fun DataFrame.Companion.getSchemaForSqlTable(dbConfig: DbConnectionConfig, tableName: String): DataFrameSchema {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return getSchemaForSqlTable(connection, tableName)
}
Expand Down Expand Up @@ -405,10 +534,7 @@ public fun DataFrame.Companion.getSchemaForSqlTable(connection: Connection, tabl
* @param [sqlQuery] the SQL query to execute and retrieve the schema from.
* @return the schema of the SQL query as a [DataFrameSchema] object.
*/
public fun DataFrame.Companion.getSchemaForSqlQuery(
dbConfig: DatabaseConfiguration,
sqlQuery: String,
): DataFrameSchema {
public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DbConnectionConfig, sqlQuery: String): DataFrameSchema {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return getSchemaForSqlQuery(connection, sqlQuery)
}
Expand All @@ -434,6 +560,40 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQ
}
}

/**
 * Retrieves the schema of an SQL query result or of an SQL table
 * from the database described by this [DbConnectionConfig].
 *
 * The argument is checked for being an SQL query first and for being a table name second;
 * the work is then delegated to `DataFrame.getSchemaForSqlQuery`
 * or `DataFrame.getSchemaForSqlTable` respectively.
 *
 * @param [sqlQueryOrTableName] the SQL query or the name of the SQL table to retrieve the schema for.
 * @return the schema of the SQL query result or of the SQL table as a [DataFrameSchema] object.
 * @throws IllegalArgumentException if [sqlQueryOrTableName] is recognized neither as an SQL query
 * nor as a table name.
 */
public fun DbConnectionConfig.getDataFrameSchema(sqlQueryOrTableName: String): DataFrameSchema {
    if (isSqlQuery(sqlQueryOrTableName)) {
        return DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName)
    }

    if (isSqlTableName(sqlQueryOrTableName)) {
        return DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName)
    }

    throw IllegalArgumentException(
        "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
    )
}

/**
 * Retrieves the schema of an SQL query result or of an SQL table through this [Connection].
 *
 * The argument is checked for being an SQL query first and for being a table name second;
 * the work is then delegated to `DataFrame.getSchemaForSqlQuery`
 * or `DataFrame.getSchemaForSqlTable` respectively.
 *
 * @param [sqlQueryOrTableName] the SQL query or the name of the SQL table to retrieve the schema for.
 * @return the schema of the SQL query result or of the SQL table as a [DataFrameSchema] object.
 * @throws IllegalArgumentException if [sqlQueryOrTableName] is recognized neither as an SQL query
 * nor as a table name.
 */
public fun Connection.getDataFrameSchema(sqlQueryOrTableName: String): DataFrameSchema {
    if (isSqlQuery(sqlQueryOrTableName)) {
        return DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName)
    }

    if (isSqlTableName(sqlQueryOrTableName)) {
        return DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName)
    }

    throw IllegalArgumentException(
        "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
    )
}

/**
* Retrieves the schema from [ResultSet].
*
Expand All @@ -448,6 +608,16 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp
return buildSchemaByTableColumns(tableColumns, dbType)
}

/**
 * Builds the [DataFrameSchema] that describes this [ResultSet].
 *
 * NOTE: This function will not close connection and result set and not retrieve data from the result set.
 *
 * @param [dbType] the type of database that this [ResultSet] belongs to.
 * @return the schema of the [ResultSet] as a [DataFrameSchema] object.
 */
public fun ResultSet.getDataFrameSchema(dbType: DbType): DataFrameSchema {
    // Pure delegation to the DataFrame-companion schema reader.
    return DataFrame.getSchemaForResultSet(this, dbType)
}

/**
* Retrieves the schema from [ResultSet].
*
Expand All @@ -465,13 +635,24 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, conne
return buildSchemaByTableColumns(tableColumns, dbType)
}

/**
 * Builds the [DataFrameSchema] that describes this [ResultSet].
 *
 * NOTE: This function will not close connection and result set and not retrieve data from the result set.
 *
 * @param [connection] the connection to the database (it's required to extract the database type).
 * @return the schema of the [ResultSet] as a [DataFrameSchema] object.
 */
public fun ResultSet.getDataFrameSchema(connection: Connection): DataFrameSchema {
    // Pure delegation to the DataFrame-companion schema reader.
    return DataFrame.getSchemaForResultSet(this, connection)
}

/**
* Retrieves the schemas of all non-system tables in the database using the provided database configuration.
*
* @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
* @return a map of [String, DataFrameSchema] objects representing the table name and its schema for each non-system table.
*/
public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): Map<String, DataFrameSchema> {
public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DbConnectionConfig): Map<String, DataFrameSchema> {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return getSchemaForAllSqlTables(connection)
}
Expand Down
Loading

0 comments on commit 958f1df

Please sign in to comment.