flaw-sqlite/Flaw/Data/Sqlite.hs
{-|
Module: Flaw.Data.Sqlite
Description: Simple SQLite Haskell interface.
License: MIT
-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -fno-warn-missing-pattern-synonym-signatures #-}
module Flaw.Data.Sqlite
  ( SqliteDb()
  , SqliteStmt()
  , SqliteQuery()
  , sqliteDb
  , sqliteExec
  , sqliteStmt
  , sqliteQuery
  , sqliteStep
  , sqliteFinalStep
  , sqliteTransaction
  , SqliteData(..)
  , sqliteLastInsertRowId
  -- exception type thrown by the functions above; exported so that callers
  -- can catch it and inspect its constructors instead of having to fall
  -- back to catching SomeException
  , SqliteError(..)
  ) where
import Control.Exception
import Control.Monad
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Data.Int
import Data.IORef
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Foreign.C.String
import Foreign.C.Types
import Foreign.Marshal.Alloc
import Foreign.Ptr
import Foreign.Storable
import Flaw.Book
-- | Handle of an opened database, bundled with the prepared statements
-- that 'sqliteTransaction' steps through.
data SqliteDb = SqliteDb
  { sqliteDbPtr :: !(Ptr C_sqlite3)
    -- ^ Raw connection pointer filled in by @sqlite3_open@.
  , sqliteDbSavePointStmtPtr :: !(Ptr C_sqlite3_stmt)
    -- ^ Prepared @SAVEPOINT T@ statement.
  , sqliteDbReleaseStmtPtr :: !(Ptr C_sqlite3_stmt)
    -- ^ Prepared @RELEASE T@ statement.
  , sqliteDbRollbackToStmtPtr :: !(Ptr C_sqlite3_stmt)
    -- ^ Prepared @ROLLBACK TO T@ statement.
  }
-- | Prepared statement as created by 'sqliteStmt'.
data SqliteStmt = SqliteStmt
  { sqliteStmtPtr :: !(Ptr C_sqlite3_stmt)
    -- ^ Raw prepared-statement pointer.
  , sqliteStmtDbPtr :: !(Ptr C_sqlite3)
    -- ^ Connection the statement was prepared on; error details are
    -- fetched from it when a call on the statement fails.
  }
-- | A statement while it is in use inside 'sqliteQuery'.
-- Stepping, binding and column access take this wrapper rather than a bare 'SqliteStmt'.
newtype SqliteQuery = SqliteQuery SqliteStmt
-- | Open the SQLite database stored under the given file name.
-- The result is paired with the cleanup action produced by 'withSpecialBook';
-- the cleanups registered here finalize the three transaction statements and
-- close the connection.
-- Throws 'SqliteOpenError' when @sqlite3_open@ does not report success.
sqliteDb :: T.Text -> IO (SqliteDb, IO ())
sqliteDb fileName = withSpecialBook $ \bk -> do
  -- open the connection and register its closing in the book
  dbPtr <- book bk $ B.useAsCString (T.encodeUtf8 fileName) $ \fileNamePtr ->
    alloca $ \dbPtrPtr -> do
      openResult <- sqlite3_open fileNamePtr dbPtrPtr
      openedPtr <- peek dbPtrPtr
      unless (openResult == SQLITE_OK) $ do
        -- if a handle was produced despite the failure, close it before bailing out
        unless (openedPtr == nullPtr) $ void $ sqlite3_close openedPtr
        throwIO $ SqliteOpenError fileName
      return (openedPtr, void $ sqlite3_close openedPtr)
  -- prepare a statement on the new connection, then register its finalization
  let
    prepare sql = do
      stmtPtr <- alloca $ \stmtPtrPtr -> do
        sqliteCheckError dbPtr (== SQLITE_OK) $ withCString sql $ \sqlPtr ->
          sqlite3_prepare_v2 dbPtr sqlPtr (-1) stmtPtrPtr nullPtr
        peek stmtPtrPtr
      book bk $ return (stmtPtr, void $ sqlite3_finalize stmtPtr)
  -- statements used by sqliteTransaction, prepared in field order
  SqliteDb dbPtr
    <$> prepare "SAVEPOINT T"
    <*> prepare "RELEASE T"
    <*> prepare "ROLLBACK TO T"
-- | Run a piece of SQL text right away, without creating a reusable statement.
-- Throws 'SqliteError' unless @sqlite3_exec@ returns @SQLITE_OK@.
sqliteExec :: SqliteDb -> T.Text -> IO ()
sqliteExec db text = sqliteCheckError dbPtr (== SQLITE_OK) runExec
  where
    dbPtr = sqliteDbPtr db
    -- neither a row callback nor an error message buffer is passed
    runExec = B.useAsCString (T.encodeUtf8 text) $ \textPtr ->
      sqlite3_exec dbPtr textPtr nullFunPtr nullPtr nullPtr
-- | Prepare a reusable statement from SQL text.
-- Returns the statement together with an action finalizing it.
-- Throws 'SqliteError' unless @sqlite3_prepare_v2@ returns @SQLITE_OK@.
sqliteStmt :: SqliteDb -> T.Text -> IO (SqliteStmt, IO ())
sqliteStmt db text = do
  let
    dbPtr = sqliteDbPtr db
    -- the byte length is passed explicitly, so the buffer is used as is
    prepare stmtPtrPtr = B.unsafeUseAsCStringLen (T.encodeUtf8 text) $ \(textPtr, textLen) ->
      sqlite3_prepare_v2 dbPtr textPtr (fromIntegral textLen) stmtPtrPtr nullPtr
  stmtPtr <- alloca $ \stmtPtrPtr -> do
    sqliteCheckError dbPtr (== SQLITE_OK) $ prepare stmtPtrPtr
    peek stmtPtrPtr
  let
    stmt = SqliteStmt
      { sqliteStmtPtr = stmtPtr
      , sqliteStmtDbPtr = dbPtr
      }
  return (stmt, void $ sqlite3_finalize stmtPtr)
-- | Run an action with a query object made from the statement.
-- Whether the action returns or throws, the statement is reset and its
-- bindings are cleared afterwards.
sqliteQuery :: SqliteStmt -> (SqliteQuery -> IO a) -> IO a
sqliteQuery stmt action = action (SqliteQuery stmt) `finally` cleanup
  where
    stmtPtr = sqliteStmtPtr stmt
    -- result codes of both calls are ignored
    cleanup = do
      _ <- sqlite3_reset stmtPtr
      _ <- sqlite3_clear_bindings stmtPtr
      return ()
-- | Advance the query by one step.
-- Yields True when the step produced a row of data (@SQLITE_ROW@),
-- False when the statement has run to completion (@SQLITE_DONE@),
-- and throws 'SqliteError' for any other result code.
sqliteStep :: SqliteQuery -> IO Bool
sqliteStep (SqliteQuery stmt) = sqlite3_step (sqliteStmtPtr stmt) >>= interpret
  where
    interpret code
      | code == SQLITE_ROW = return True
      | code == SQLITE_DONE = return False
      | otherwise = throwSqliteError (sqliteStmtDbPtr stmt)
-- | Advance the query by one step and require that it was the last one.
-- Throws 'SqliteStepNotFinal' if the step produced a row of data instead.
sqliteFinalStep :: SqliteQuery -> IO ()
sqliteFinalStep query =
  sqliteStep query >>= \hasRow -> when hasRow (throwIO SqliteStepNotFinal)
-- | Run an action inside the savepoint @T@.
-- The action receives a commit function which releases the savepoint.
-- Calling commit once the transaction is finished throws
-- 'SqliteTransactionAlreadyFinished'.
-- If the action ends (by returning or by throwing) without a successful
-- commit, the savepoint is rolled back and then released.
sqliteTransaction :: SqliteDb -> (IO () -> IO a) -> IO a
sqliteTransaction SqliteDb
  { sqliteDbPtr = dbPtr
  , sqliteDbSavePointStmtPtr = savePointStmtPtr
  , sqliteDbReleaseStmtPtr = releaseStmtPtr
  , sqliteDbRollbackToStmtPtr = rollbackToStmtPtr
  } io = do
  -- establish the savepoint
  stepChecked savePointStmtPtr
  -- becomes True after a successful commit
  finishedRef <- newIORef False
  let
    commit = do
      alreadyFinished <- readIORef finishedRef
      when alreadyFinished $ throwIO SqliteTransactionAlreadyFinished
      stepChecked releaseStmtPtr
      writeIORef finishedRef True
    rollbackUnlessFinished = do
      finished <- readIORef finishedRef
      unless finished $ do
        stepUnchecked rollbackToStmtPtr
        stepUnchecked releaseStmtPtr
  io commit `finally` rollbackUnlessFinished
  where
    -- reset and run a statement, requiring it to complete with SQLITE_DONE
    stepChecked stmtPtr = do
      void $ sqlite3_reset stmtPtr
      sqliteCheckError dbPtr (== SQLITE_DONE) $ sqlite3_step stmtPtr
    -- reset and run a statement, ignoring the result codes
    stepUnchecked stmtPtr = do
      void $ sqlite3_reset stmtPtr
      void $ sqlite3_step stmtPtr
-- | Types whose values can be bound into statements and read back from queries.
class SqliteData a where
  -- | Bind a value to the statement parameter with the given index.
  sqliteBind :: SqliteQuery -> CInt -> a -> IO ()
  -- | Read the value of the column with the given index.
  sqliteColumn :: SqliteQuery -> CInt -> IO a
-- | C @int@ values, via @sqlite3_bind_int@ and @sqlite3_column_int@.
instance SqliteData CInt where
  -- throws SqliteError unless the bind call returns SQLITE_OK
  sqliteBind (SqliteQuery stmt) index value =
    sqliteCheckError (sqliteStmtDbPtr stmt) (== SQLITE_OK) $
      sqlite3_bind_int (sqliteStmtPtr stmt) index value
  -- the column getter reports no status, so nothing is checked here
  sqliteColumn (SqliteQuery stmt) = sqlite3_column_int (sqliteStmtPtr stmt)
-- | 64-bit integers, via @sqlite3_bind_int64@ and @sqlite3_column_int64@.
instance SqliteData Int64 where
  -- throws SqliteError unless the bind call returns SQLITE_OK
  sqliteBind (SqliteQuery stmt) index value =
    sqliteCheckError (sqliteStmtDbPtr stmt) (== SQLITE_OK) $
      sqlite3_bind_int64 (sqliteStmtPtr stmt) index value
  -- the column getter reports no status, so nothing is checked here
  sqliteColumn (SqliteQuery stmt) = sqlite3_column_int64 (sqliteStmtPtr stmt)
-- | Byte strings, via @sqlite3_bind_blob@ and @sqlite3_column_blob@.
instance SqliteData B.ByteString where
  -- Bind the bytes as a blob; throws SqliteError unless the call returns SQLITE_OK.
  -- SQLITE_TRANSIENT is passed as the destructor, which per the SQLite API makes
  -- SQLite take its own copy, so the buffer only has to stay valid during the call.
  sqliteBind (SqliteQuery SqliteStmt
    { sqliteStmtPtr = stmtPtr
    , sqliteStmtDbPtr = dbPtr
    }) column bytes =
    sqliteCheckError dbPtr (== SQLITE_OK) $ B.unsafeUseAsCStringLen bytes $ \(ptr, len) ->
      -- note: we are forcing non-null pointer in case of zero-length bytestring, in order to bind a blob and not a NULL value
      sqlite3_bind_blob stmtPtr column (if len > 0 then castPtr ptr else intPtrToPtr 1) (fromIntegral len) $
        castPtrToFunPtr $ intPtrToPtr SQLITE_TRANSIENT
  -- Copy the blob of the given column into a fresh bytestring.
  -- The pointer is fetched before the length, in that order.
  sqliteColumn (SqliteQuery SqliteStmt
    { sqliteStmtPtr = stmtPtr
    }) column = do
    ptr <- sqlite3_column_blob stmtPtr column
    len <- sqlite3_column_bytes stmtPtr column
    -- SQLite hands back a NULL pointer for zero-length blobs and NULL values;
    -- answer with an empty bytestring directly instead of passing a null
    -- source pointer to the copying routine
    if ptr == nullPtr || len <= 0
      then return B.empty
      else B.packCStringLen (castPtr ptr, fromIntegral len)
-- | Text values, encoded as UTF-8, via @sqlite3_bind_text@ and @sqlite3_column_text@.
instance SqliteData T.Text where
  -- Bind the UTF-8 encoding of the text; throws SqliteError unless the call returns SQLITE_OK.
  -- SQLITE_TRANSIENT is passed as the destructor, which per the SQLite API makes
  -- SQLite take its own copy, so the buffer only has to stay valid during the call.
  sqliteBind (SqliteQuery SqliteStmt
    { sqliteStmtPtr = stmtPtr
    , sqliteStmtDbPtr = dbPtr
    }) column text =
    sqliteCheckError dbPtr (== SQLITE_OK) $ B.unsafeUseAsCStringLen (T.encodeUtf8 text) $ \(ptr, len) ->
      -- note: we are forcing non-null pointer in case of zero-length string, in order to bind a string and not a NULL value
      sqlite3_bind_text stmtPtr column (if len > 0 then ptr else intPtrToPtr 1) (fromIntegral len) $
        castPtrToFunPtr $ intPtrToPtr SQLITE_TRANSIENT
  -- Copy the bytes of the given column and decode them as UTF-8.
  -- The pointer is fetched before the length, in that order.
  -- Decoding is done with T.decodeUtf8, so malformed input surfaces as that
  -- function's exception when the resulting text is evaluated.
  sqliteColumn (SqliteQuery SqliteStmt
    { sqliteStmtPtr = stmtPtr
    }) column = do
    ptr <- sqlite3_column_text stmtPtr column
    len <- sqlite3_column_bytes stmtPtr column
    -- SQLite hands back a NULL pointer for NULL values; answer with an empty
    -- text directly instead of passing a null source pointer to the copying routine
    if ptr == nullPtr || len <= 0
      then return T.empty
      else T.decodeUtf8 <$> B.packCStringLen (ptr, fromIntegral len)
-- | Row id reported by @sqlite3_last_insert_rowid@ for this database connection.
sqliteLastInsertRowId :: SqliteDb -> IO Int64
sqliteLastInsertRowId = sqlite3_last_insert_rowid . sqliteDbPtr
-- | Throw 'SqliteError' carrying the error code and the error message
-- currently reported by the connection.
throwSqliteError :: Ptr C_sqlite3 -> IO a
throwSqliteError dbPtr = do
  code <- fromIntegral <$> sqlite3_errcode dbPtr
  -- the message is copied out of SQLite's buffer before being decoded
  message <- T.decodeUtf8 <$> (B.packCString =<< sqlite3_errmsg dbPtr)
  throwIO (SqliteError code message)
-- | Run an SQLite call and test its result code with the given predicate.
-- When the predicate rejects the code, the error currently reported by the
-- connection is thrown as 'SqliteError'.
sqliteCheckError :: Ptr C_sqlite3 -> (CInt -> Bool) -> IO CInt -> IO ()
sqliteCheckError dbPtr cond io =
  io >>= \code -> unless (cond code) (throwSqliteError dbPtr)
-- | Exceptions thrown by this module.
data SqliteError
  -- | SQLite reported a failure; carries the error code and the error message
  -- read from the connection.
  = SqliteError {-# UNPACK #-} !Int !T.Text
  -- | The database with the given file name could not be opened.
  | SqliteOpenError !T.Text
  -- | 'sqliteFinalStep' got a row of data instead of completion.
  | SqliteStepNotFinal
  -- | The commit function of 'sqliteTransaction' was called after the
  -- transaction had already been committed.
  | SqliteTransactionAlreadyFinished
  deriving Show

instance Exception SqliteError
-- FFI: types
-- Constructor-less tag types: they only label the pointers exchanged with
-- the SQLite C functions and carry no data on the Haskell side.
-- | Tag for pointers to a database connection (filled in by @sqlite3_open@).
data C_sqlite3
-- | Tag for pointers to a prepared statement (filled in by @sqlite3_prepare_v2@).
data C_sqlite3_stmt
-- FFI: functions
-- Raw bindings to the SQLite C API; each Haskell name is also the C symbol name.
-- Functions returning CInt return an SQLite result code.
-- connection lifetime
foreign import ccall safe sqlite3_open :: Ptr CChar -> Ptr (Ptr C_sqlite3) -> IO CInt
foreign import ccall safe sqlite3_close :: Ptr C_sqlite3 -> IO CInt
-- statement lifetime and execution
-- arguments of prepare: connection, SQL text, byte length of the text (-1 is passed for
-- zero-terminated strings), out-pointer for the statement, out-pointer for the unparsed tail
foreign import ccall safe sqlite3_prepare_v2 :: Ptr C_sqlite3 -> Ptr CChar -> CInt -> Ptr (Ptr C_sqlite3_stmt) -> Ptr (Ptr CChar) -> IO CInt
foreign import ccall unsafe sqlite3_reset :: Ptr C_sqlite3_stmt -> IO CInt
foreign import ccall safe sqlite3_step :: Ptr C_sqlite3_stmt -> IO CInt
foreign import ccall unsafe sqlite3_clear_bindings :: Ptr C_sqlite3_stmt -> IO CInt
foreign import ccall unsafe sqlite3_finalize :: Ptr C_sqlite3_stmt -> IO CInt
-- one-shot execution; the callback, its user data and the error message out-pointer
-- are all passed as null by sqliteExec
foreign import ccall safe sqlite3_exec :: Ptr C_sqlite3 -> Ptr CChar -> FunPtr (Ptr () -> CInt -> Ptr (Ptr CChar) -> Ptr (Ptr CChar) -> IO CInt) -> Ptr () -> Ptr (Ptr CChar) -> IO CInt
-- parameter binding: statement, parameter index, value
-- (blob/text additionally take the byte length and a destructor for the buffer)
foreign import ccall unsafe sqlite3_bind_int :: Ptr C_sqlite3_stmt -> CInt -> CInt -> IO CInt
foreign import ccall unsafe sqlite3_bind_int64 :: Ptr C_sqlite3_stmt -> CInt -> Int64 -> IO CInt
-- NOTE(review): sqlite3_bind_blob is imported `safe` while sqlite3_bind_text is `unsafe`;
-- confirm whether the difference is intentional
foreign import ccall safe sqlite3_bind_blob :: Ptr C_sqlite3_stmt -> CInt -> Ptr () -> CInt -> FunPtr (Ptr () -> IO ()) -> IO CInt
foreign import ccall unsafe sqlite3_bind_text :: Ptr C_sqlite3_stmt -> CInt -> Ptr CChar -> CInt -> FunPtr (Ptr () -> IO ()) -> IO CInt
--foreign import ccall unsafe sqlite3_bind_null :: Ptr C_sqlite3_stmt -> CInt -> IO CInt
-- column access: statement, column index
foreign import ccall unsafe sqlite3_column_int :: Ptr C_sqlite3_stmt -> CInt -> IO CInt
foreign import ccall unsafe sqlite3_column_int64 :: Ptr C_sqlite3_stmt -> CInt -> IO Int64
foreign import ccall unsafe sqlite3_column_blob :: Ptr C_sqlite3_stmt -> CInt -> IO (Ptr ())
foreign import ccall unsafe sqlite3_column_bytes :: Ptr C_sqlite3_stmt -> CInt -> IO CInt
foreign import ccall unsafe sqlite3_column_text :: Ptr C_sqlite3_stmt -> CInt -> IO (Ptr CChar)
-- connection state queries
foreign import ccall unsafe sqlite3_last_insert_rowid :: Ptr C_sqlite3 -> IO Int64
foreign import ccall unsafe sqlite3_errcode :: Ptr C_sqlite3 -> IO CInt
foreign import ccall unsafe sqlite3_errmsg :: Ptr C_sqlite3 -> IO (Ptr CChar)
-- FFI: values
-- SQLite constants as pattern synonyms over numeric literals, declared without
-- signatures; they are used both in patterns and as expressions in this module.
-- result code that open, prepare, exec and the bind calls are checked against
pattern SQLITE_OK = 0
-- result of sqlite3_step which sqliteStep reports as True (a row of data)
pattern SQLITE_ROW = 100
-- result of sqlite3_step which sqliteStep reports as False (no more data)
pattern SQLITE_DONE = 101
-- converted to a function pointer and passed as the destructor argument
-- of sqlite3_bind_blob / sqlite3_bind_text
pattern SQLITE_TRANSIENT = -1