use std::ffi::{c_int, CString}; use std::marker::PhantomData; use std::{slice, str}; use anyhow::{anyhow, Context, Result}; use libsqlite3_sys::*; use crate::bindable::{Bind, Column}; use crate::connection::Connection; pub struct Statement<'a> { raw_statement: *mut sqlite3_stmt, connection: &'a Connection, phantom: PhantomData, } #[derive(Clone, Copy, PartialEq, Eq, Debug)] pub enum StepResult { Row, Done, Misuse, Other(i32), } #[derive(Clone, Copy, PartialEq, Eq, Debug)] pub enum SqlType { Text, Integer, Blob, Float, Null, } impl<'a> Statement<'a> { pub fn prepare>(connection: &'a Connection, query: T) -> Result { let mut statement = Self { raw_statement: 0 as *mut _, connection, phantom: PhantomData, }; unsafe { sqlite3_prepare_v2( connection.sqlite3, CString::new(query.as_ref())?.as_ptr(), -1, &mut statement.raw_statement, 0 as *mut _, ); connection .last_error() .with_context(|| format!("Prepare call failed for query:\n{}", query.as_ref()))?; } Ok(statement) } pub fn reset(&mut self) { unsafe { sqlite3_reset(self.raw_statement); } } pub fn parameter_count(&self) -> i32 { unsafe { sqlite3_bind_parameter_count(self.raw_statement) } } pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> { // dbg!("bind blob", index); let index = index as c_int; let blob_pointer = blob.as_ptr() as *const _; let len = blob.len() as c_int; unsafe { sqlite3_bind_blob( self.raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT(), ); } self.connection.last_error() } pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> { let index = index as c_int; let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) }; self.connection.last_error()?; if pointer.is_null() { return Ok(&[]); } let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize }; self.connection.last_error()?; unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) } } pub fn bind_double(&self, index: i32, double: f64) -> Result<()> { // dbg!("bind double", index); let index = index as c_int; unsafe { sqlite3_bind_double(self.raw_statement, index, double); } self.connection.last_error() } pub fn column_double(&self, index: i32) -> Result { let index = index as c_int; let result = unsafe { sqlite3_column_double(self.raw_statement, index) }; self.connection.last_error()?; Ok(result) } pub fn bind_int(&self, index: i32, int: i32) -> Result<()> { // dbg!("bind int", index); let index = index as c_int; unsafe { sqlite3_bind_int(self.raw_statement, index, int); }; self.connection.last_error() } pub fn column_int(&self, index: i32) -> Result { let index = index as c_int; let result = unsafe { sqlite3_column_int(self.raw_statement, index) }; self.connection.last_error()?; Ok(result) } pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> { // dbg!("bind int64", index); let index = index as c_int; unsafe { sqlite3_bind_int64(self.raw_statement, index, int); } self.connection.last_error() } pub fn column_int64(&self, index: i32) -> Result { let index = index as c_int; let result = unsafe { sqlite3_column_int64(self.raw_statement, index) }; self.connection.last_error()?; Ok(result) } pub fn bind_null(&self, index: i32) -> Result<()> { // dbg!("bind null", index); let index = index as c_int; unsafe { sqlite3_bind_null(self.raw_statement, index); } self.connection.last_error() } pub fn bind_text(&self, index: i32, text: &str) -> Result<()> { // dbg!("bind text", index, text); let index = index as c_int; let text_pointer = text.as_ptr() as *const _; let len = text.len() as c_int; unsafe { sqlite3_bind_text( self.raw_statement, index, text_pointer, len, SQLITE_TRANSIENT(), ); } self.connection.last_error() } pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> { let index = index as c_int; let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) }; self.connection.last_error()?; if pointer.is_null() { return Ok(""); } let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize }; self.connection.last_error()?; let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) }; Ok(str::from_utf8(slice)?) } pub fn bind(&self, value: T, index: i32) -> Result { debug_assert!(index > 0); value.bind(self, index) } pub fn column(&mut self) -> Result { let (result, _) = T::column(self, 0)?; Ok(result) } pub fn column_type(&mut self, index: i32) -> Result { let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT FROM TABLE self.connection.last_error()?; match result { SQLITE_INTEGER => Ok(SqlType::Integer), SQLITE_FLOAT => Ok(SqlType::Float), SQLITE_TEXT => Ok(SqlType::Text), SQLITE_BLOB => Ok(SqlType::Blob), SQLITE_NULL => Ok(SqlType::Null), _ => Err(anyhow!("Column type returned was incorrect ")), } } pub fn with_bindings(&mut self, bindings: impl Bind) -> Result<&mut Self> { self.bind(bindings, 1)?; Ok(self) } fn step(&mut self) -> Result { unsafe { match sqlite3_step(self.raw_statement) { SQLITE_ROW => Ok(StepResult::Row), SQLITE_DONE => Ok(StepResult::Done), SQLITE_MISUSE => Ok(StepResult::Misuse), other => self .connection .last_error() .map(|_| StepResult::Other(other)), } } } pub fn insert(&mut self) -> Result { self.exec()?; Ok(self.connection.last_insert_id()) } pub fn exec(&mut self) -> Result<()> { fn logic(this: &mut Statement) -> Result<()> { while this.step()? == StepResult::Row {} Ok(()) } let result = logic(self); self.reset(); result } pub fn map(&mut self, callback: impl FnMut(&mut Statement) -> Result) -> Result> { fn logic( this: &mut Statement, mut callback: impl FnMut(&mut Statement) -> Result, ) -> Result> { let mut mapped_rows = Vec::new(); while this.step()? == StepResult::Row { mapped_rows.push(callback(this)?); } Ok(mapped_rows) } let result = logic(self, callback); self.reset(); result } pub fn rows(&mut self) -> Result> { self.map(|s| s.column::()) } pub fn single(&mut self, callback: impl FnOnce(&mut Statement) -> Result) -> Result { fn logic( this: &mut Statement, callback: impl FnOnce(&mut Statement) -> Result, ) -> Result { if this.step()? != StepResult::Row { return Err(anyhow!( "Single(Map) called with query that returns no rows." )); } callback(this) } let result = logic(self, callback); self.reset(); result } pub fn row(&mut self) -> Result { self.single(|this| this.column::()) } pub fn maybe( &mut self, callback: impl FnOnce(&mut Statement) -> Result, ) -> Result> { fn logic( this: &mut Statement, callback: impl FnOnce(&mut Statement) -> Result, ) -> Result> { if this.step()? != StepResult::Row { return Ok(None); } callback(this).map(|r| Some(r)) } let result = logic(self, callback); self.reset(); result } pub fn maybe_row(&mut self) -> Result> { self.maybe(|this| this.column::()) } } impl<'a> Drop for Statement<'a> { fn drop(&mut self) { unsafe { sqlite3_finalize(self.raw_statement) }; } } #[cfg(test)] mod test { use indoc::indoc; use crate::{connection::Connection, statement::StepResult}; #[test] fn blob_round_trips() { let connection1 = Connection::open_memory("blob_round_trips"); connection1 .exec(indoc! {" CREATE TABLE blobs ( data BLOB );"}) .unwrap(); let blob = &[0, 1, 2, 4, 8, 16, 32, 64]; let mut write = connection1 .prepare("INSERT INTO blobs (data) VALUES (?);") .unwrap(); write.bind_blob(1, blob).unwrap(); assert_eq!(write.step().unwrap(), StepResult::Done); // Read the blob from the let connection2 = Connection::open_memory("blob_round_trips"); let mut read = connection2.prepare("SELECT * FROM blobs;").unwrap(); assert_eq!(read.step().unwrap(), StepResult::Row); assert_eq!(read.column_blob(0).unwrap(), blob); assert_eq!(read.step().unwrap(), StepResult::Done); // Delete the added blob and verify its deleted on the other side connection2.exec("DELETE FROM blobs;").unwrap(); let mut read = connection1.prepare("SELECT * FROM blobs;").unwrap(); assert_eq!(read.step().unwrap(), StepResult::Done); } #[test] pub fn maybe_returns_options() { let connection = Connection::open_memory("maybe_returns_options"); connection .exec(indoc! {" CREATE TABLE texts ( text TEXT );"}) .unwrap(); assert!(connection .prepare("SELECT text FROM texts") .unwrap() .maybe_row::() .unwrap() .is_none()); let text_to_insert = "This is a test"; connection .prepare("INSERT INTO texts VALUES (?)") .unwrap() .with_bindings(text_to_insert) .unwrap() .exec() .unwrap(); assert_eq!( connection .prepare("SELECT text FROM texts") .unwrap() .maybe_row::() .unwrap(), Some(text_to_insert.to_string()) ); } }