Skip to content

Commit 1edace0

Browse files
authored
Add Manual Storable Instances for DPIXid and VersionInfo (#26) (#42)
* Added storable instances for DPIXid and VersionInfo * Added storable roundtrip test for VersionInfo and DPIXid type * Fixed example from README
1 parent e0b0eb6 commit 1edace0

File tree

4 files changed

+141
-13
lines changed

4 files changed

+141
-13
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import GHC.Generics (Generic)
2121
main :: IO ()
2222
main = do
2323
let stmt = "select count(*), sysdate, 'ignore next column', 125.24, 3.14 from dual"
24-
conn <- createConn (ConnectionParams "username" "password" "localhost/XEPDB1")
25-
rows <- query @ReturnedRow conn stmt
24+
conn <- connect (ConnectionParams "username" "password" "localhost/XEPDB1" Nothing)
25+
rows <- query_ conn stmt :: IO [ReturnedRow]
2626
print rows
2727

2828
-- [ ReturnedRow { count = RowCount {getRowCount = 1.0}

src/Database/Oracle/Simple/Internal.hs

+23-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ module Database.Oracle.Simple.Internal
3333
ConnectionParams (..),
3434
OracleError (..),
3535
ErrorInfo (..),
36+
VersionInfo (..),
3637
renderErrorInfo,
3738
ping,
3839
fetch,
@@ -538,8 +539,28 @@ data VersionInfo = VersionInfo
538539
, portUpdateNum :: CInt
539540
, fullVersionNum :: CUInt
540541
}
541-
deriving (Show, Eq, Generic)
542-
deriving anyclass (GStorable)
542+
deriving (Show, Eq)
543+
544+
instance Storable VersionInfo where
545+
sizeOf _ = sizeOf (undefined :: CInt) * 6
546+
alignment _ = alignment (undefined :: CInt)
547+
peek p = do
548+
let basePtr = castPtr p
549+
versionNum <- peekByteOff basePtr 0
550+
releaseNum <- peekByteOff basePtr 4
551+
updateNum <- peekByteOff basePtr 8
552+
portReleaseNum <- peekByteOff basePtr 12
553+
portUpdateNum <- peekByteOff basePtr 16
554+
fullVersionNum <- peekByteOff basePtr 20
555+
return VersionInfo{..}
556+
poke p VersionInfo{..} = do
557+
let basePtr = castPtr p
558+
pokeByteOff basePtr 0 versionNum
559+
pokeByteOff basePtr 4 releaseNum
560+
pokeByteOff basePtr 8 updateNum
561+
pokeByteOff basePtr 12 portReleaseNum
562+
pokeByteOff basePtr 16 portUpdateNum
563+
pokeByteOff basePtr 20 fullVersionNum
543564

544565
getClientVersion ::
545566
IO VersionInfo

src/Database/Oracle/Simple/Transaction.hs

+86-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
{-# LANGUAGE DeriveAnyClass #-}
2-
{-# LANGUAGE DeriveGeneric #-}
31
{-# LANGUAGE DerivingStrategies #-}
42
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
53
{-# LANGUAGE RecordWildCards #-}
64
{-# LANGUAGE ScopedTypeVariables #-}
75
{-# LANGUAGE ViewPatterns #-}
86

97
module Database.Oracle.Simple.Transaction
10-
( beginTransaction,
8+
(
9+
DPIXid (..),
10+
beginTransaction,
1111
commitTransaction,
1212
prepareCommit,
1313
withTransaction,
@@ -19,12 +19,11 @@ import Control.Exception (catch, throw)
1919
import Control.Monad (replicateM, void, when, (<=<))
2020
import Data.UUID (UUID, toString)
2121
import Data.UUID.V4 (nextRandom)
22-
import Foreign (alloca, peek, poke, withForeignPtr)
22+
import Foreign (alloca, withForeignPtr)
2323
import Foreign.C.String (CString, withCStringLen)
2424
import Foreign.C.Types (CInt (CInt), CLong, CUInt (CUInt))
25-
import Foreign.Ptr (Ptr)
26-
import Foreign.Storable.Generic (GStorable)
27-
import GHC.Generics (Generic)
25+
import Foreign.Ptr (Ptr ,castPtr)
26+
import Foreign.Storable (Storable(..))
2827
import System.Random (getStdRandom, uniformR)
2928

3029
import Database.Oracle.Simple.Execute (execute_)
@@ -141,8 +140,86 @@ data DPIXid = DPIXid
141140
, dpixBranchQualifier :: CString
142141
, dpixBranchQualifierLength :: CUInt
143142
}
144-
deriving (Generic, Show)
145-
deriving anyclass (GStorable)
143+
deriving (Show, Eq)
144+
145+
instance Storable DPIXid where
146+
sizeOf _ =
147+
let
148+
-- Sizes of fields
149+
sizeFormatId = sizeOf (undefined :: CLong)
150+
sizeTransactionId = sizeOf (undefined :: CString)
151+
sizeTransactionIdLength = sizeOf (undefined :: CUInt)
152+
sizeQualifier = sizeOf (undefined :: CString)
153+
sizeQualifierLength = sizeOf (undefined :: CUInt)
154+
155+
-- Alignments of fields
156+
alignFormatId = alignment (undefined :: CLong)
157+
alignTransactionId = alignment (undefined :: CString)
158+
alignTransactionIdLength = alignment (undefined :: CUInt)
159+
alignQualifier = alignment (undefined :: CString)
160+
alignQualifierLength = alignment (undefined :: CUInt)
161+
162+
-- Padding for each field
163+
paddingTransactionId = padding sizeFormatId alignTransactionId
164+
paddingTransactionIdLength = padding (sizeTransactionId + paddingTransactionId) alignTransactionIdLength
165+
paddingQualifier = padding (sizeTransactionIdLength + paddingTransactionIdLength) alignQualifier
166+
paddingQualifierLength = padding (sizeQualifier + paddingQualifier) alignQualifierLength
167+
in
168+
sizeFormatId +
169+
paddingTransactionId + sizeTransactionId +
170+
paddingTransactionIdLength + sizeTransactionIdLength +
171+
paddingQualifier + sizeQualifier +
172+
paddingQualifierLength + sizeQualifierLength +
173+
-- Final padding to align the structure itself
174+
padding (sizeFormatId +
175+
paddingTransactionId + sizeTransactionId +
176+
paddingTransactionIdLength + sizeTransactionIdLength +
177+
paddingQualifier + sizeQualifier +
178+
paddingQualifierLength + sizeQualifierLength) alignFormatId
179+
180+
alignment _ = alignment (undefined :: CLong)
181+
182+
peek p = do
183+
let basePtr = castPtr p
184+
formatId <- peekByteOff basePtr 0
185+
186+
let offsetTransactionId = alignedOffset 0 (sizeOf (undefined :: CLong)) (alignment (undefined :: CString))
187+
transactionId <- peekByteOff basePtr offsetTransactionId
188+
189+
let offsetTransactionIdLength = offsetTransactionId + sizeOf (undefined :: CString)
190+
transactionIdLength <- peekByteOff basePtr offsetTransactionIdLength
191+
192+
let offsetQualifier = alignedOffset offsetTransactionIdLength (sizeOf (undefined :: CUInt)) (alignment (undefined :: CString))
193+
qualifier <- peekByteOff basePtr offsetQualifier
194+
195+
let offsetQualifierLength = offsetQualifier + sizeOf (undefined :: CString)
196+
qualifierLength <- peekByteOff basePtr offsetQualifierLength
197+
198+
return $ DPIXid formatId transactionId transactionIdLength qualifier qualifierLength
199+
200+
poke p (DPIXid formatId transactionId transactionIdLength qualifier qualifierLength) = do
201+
let basePtr = castPtr p
202+
pokeByteOff basePtr 0 formatId
203+
204+
let offsetTransactionId = alignedOffset 0 (sizeOf (undefined :: CLong)) (alignment (undefined :: CString))
205+
pokeByteOff basePtr offsetTransactionId transactionId
206+
207+
let offsetTransactionIdLength = offsetTransactionId + sizeOf (undefined :: CString)
208+
pokeByteOff basePtr offsetTransactionIdLength transactionIdLength
209+
210+
let offsetQualifier = alignedOffset offsetTransactionIdLength (sizeOf (undefined :: CUInt)) (alignment (undefined :: CString))
211+
pokeByteOff basePtr offsetQualifier qualifier
212+
213+
let offsetQualifierLength = offsetQualifier + sizeOf (undefined :: CString)
214+
pokeByteOff basePtr offsetQualifierLength qualifierLength
215+
216+
-- Helper to calculate padding between fields
217+
padding :: Int -> Int -> Int
218+
padding size align = (align - size `mod` align) `mod` align
219+
220+
-- Helper to calculate aligned offsets
221+
alignedOffset :: Int -> Int -> Int -> Int
222+
alignedOffset base size align = base + size + padding (base + size) align
146223

147224
withDPIXid :: Transaction -> (Ptr DPIXid -> IO a) -> IO a
148225
withDPIXid Transaction {..} action =

test/Main.hs

+30
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ import qualified Hedgehog.Gen as Gen
2222
import qualified Hedgehog.Range as Range
2323
import Test.Hspec (Spec, around, describe, hspec, it, shouldBe)
2424
import Test.Hspec.Hedgehog (hedgehog)
25+
import Foreign (peek, Storable, with)
26+
import Foreign.C.Types (CLong(..), CUInt(..))
27+
import Foreign.C.String (newCString)
2528

2629
import Database.Oracle.Simple
2730

@@ -242,5 +245,32 @@ spec pool = do
242245
results <- query_ @(Only Int) conn "select * from transactions_test"
243246
void $ execute_ conn "drop table transactions_test"
244247
results `shouldBe` [Only 1 .. Only 8] <> [Only 10]
248+
describe "Storable round trip tests" $ do
249+
it "VersionInfo" $ \_ -> do
250+
let versionInfo = VersionInfo {
251+
versionNum = 1
252+
, releaseNum = 2
253+
, updateNum = 3
254+
, portReleaseNum = 4
255+
, portUpdateNum = 5
256+
, fullVersionNum = 6
257+
}
258+
result <- roundTripStorable versionInfo
259+
result `shouldBe` versionInfo
260+
it "DPIXid" $ \_ -> do
261+
someCString <- newCString "hello"
262+
let dpixid = DPIXid {
263+
dpixFormatId = CLong 1
264+
, dpixGlobalTransactionId = someCString
265+
, dpixGlobalTransactionIdLength = CUInt 2
266+
, dpixBranchQualifier = someCString
267+
, dpixBranchQualifierLength = CUInt 3
268+
}
269+
result <- roundTripStorable dpixid
270+
result `shouldBe` dpixid
245271
where
246272
handleOracleError action = Exc.try @OracleError action >>= either (\_ -> pure ()) (\_ -> pure ())
273+
274+
-- | Round-trip a value through its `Storable` instance.
275+
roundTripStorable :: Storable a => a -> IO a
276+
roundTripStorable x = with x peek

0 commit comments

Comments
 (0)