Skip to content

Commit

Permalink
Optimize executescript() to use batching
Browse files Browse the repository at this point in the history
Refs #70
  • Loading branch information
penberg committed Aug 10, 2024
1 parent b15302e commit 8b40f28
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,9 @@ impl Connection {
}

fn executescript(self_: PyRef<'_, Self>, script: String) -> PyResult<()> {
let statements = script.split(';');
for statement in statements {
let statement = statement.trim();
if !statement.is_empty() {
let cursor = Connection::cursor(&self_)?;
self_
.rt
.block_on(async { execute(&cursor, statement.to_string(), None).await })?;
}
}
let _ = self_.rt.block_on(async {
self_.conn.execute_batch(&script).await
}).map_err(to_py_err);
Ok(())
}

Expand Down Expand Up @@ -272,6 +265,16 @@ impl Cursor {
Ok(self_)
}

fn executescript<'a>(self_: PyRef<'a, Self>, script: String) -> PyResult<pyo3::PyRef<'a, Self>> {
self_
.rt
.block_on(async {
self_.conn.execute_batch(&script).await
})
.map_err(to_py_err)?;
Ok(self_)
}

#[getter]
fn description(self_: PyRef<'_, Self>) -> PyResult<Option<&PyTuple>> {
let stmt = self_.stmt.borrow();
Expand Down
13 changes: 13 additions & 0 deletions tests/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ def test_cursor_executemany(provider):
res = cur.execute("SELECT * FROM users")
assert [(1, '[email protected]'), (2, '[email protected]')] == res.fetchall()

@pytest.mark.parametrize("provider", ["libsql", "sqlite"])
def test_cursor_executescript(provider):
conn = connect(provider, ":memory:")
cur = conn.cursor()
cur.executescript("""
CREATE TABLE users (id INTEGER, email TEXT);
INSERT INTO users VALUES (1, '[email protected]');
INSERT INTO users VALUES (2, '[email protected]');
""")
res = cur.execute("SELECT * FROM users")
assert (1, '[email protected]') == res.fetchone()
assert (2, '[email protected]') == res.fetchone()

@pytest.mark.parametrize("provider", ["libsql", "sqlite"])
def test_lastrowid(provider):
conn = connect(provider, ":memory:")
Expand Down

0 comments on commit 8b40f28

Please sign in to comment.