From e0f4af5449348d5dc231425e06290a98baf02a5b Mon Sep 17 00:00:00 2001 From: David Johnson Date: Wed, 19 Jun 2019 17:21:19 -0400 Subject: [PATCH] Consider using Bool instead of CBool. --- src/ArrayFire/Array.hs | 31 +++++++++++++++++++++++++++++-- src/ArrayFire/Types.hs | 3 +-- test/ArrayFire/AlgorithmSpec.hs | 32 ++++++++++++++++---------------- test/ArrayFire/ArithSpec.hs | 16 +++++++--------- test/ArrayFire/ArraySpec.hs | 8 ++++---- test/ArrayFire/UtilSpec.hs | 2 +- 6 files changed, 58 insertions(+), 34 deletions(-) diff --git a/src/ArrayFire/Array.hs b/src/ArrayFire/Array.hs index 7cbc577..a996183 100644 --- a/src/ArrayFire/Array.hs +++ b/src/ArrayFire/Array.hs @@ -27,8 +27,10 @@ import Foreign.ForeignPtr import Foreign.Marshal hiding (void) import Foreign.Ptr import Foreign.Storable +import Foreign.C.Types import System.IO.Unsafe +import Unsafe.Coerce import ArrayFire.Exception import ArrayFire.FFI @@ -64,7 +66,16 @@ mkArray {-# NOINLINE mkArray #-} mkArray dims xs = unsafePerformIO . mask_ $ do - dataPtr <- castPtr <$> newArray (Prelude.take size xs) + dataPtr <- + case dType of + x | x == b8 -> do + let cs :: [CBool] = + fromIntegral . fromEnum <$> + Prelude.take size (unsafeCoerce xs :: [Bool]) + mapM_ print cs >> print "cbools!" + castPtr <$> newArray cs + | otherwise -> + castPtr <$> newArray (Prelude.take size xs) let ndims = fromIntegral (Prelude.length dims) alloca $ \arrayPtr -> do dimsPtr <- newArray (DimT . fromIntegral <$> dims) @@ -232,6 +243,11 @@ toVector arr@(Array fptr) = do throwAFError =<< af_get_data_ptr (castPtr ptr) arrPtr newFptr <- newForeignPtr finalizerFree ptr pure $ unsafeFromForeignPtr0 newFptr len + where + go 0 = False + go 1 = True + go _ = error "Invalid Ptr CBool" + typ = afType (Proxy @ a) toList :: forall a . AFType a => Array a -> [a] toList = V.toList . toVector @@ -241,5 +257,16 @@ getScalar (Array fptr) = unsafePerformIO . mask_ . withForeignPtr fptr $ \arrPtr -> do alloca $ \ptr -> do throwAFError =<< af_get_scalar (castPtr ptr) arrPtr - peek ptr + if typ == b8 + then do + b :: CBool <- peek (castPtr ptr) + pure . unsafeCoerce $ case b of + 0 -> False + 1 -> True + _ -> error "Invalid Ptr CBool" + else + peek ptr + where + typ = afType (Proxy @b) + diff --git a/src/ArrayFire/Types.hs b/src/ArrayFire/Types.hs index bf0c096..05a56f5 100644 --- a/src/ArrayFire/Types.hs +++ b/src/ArrayFire/Types.hs @@ -35,7 +35,6 @@ import Data.Complex import Data.Proxy import Data.Word import Foreign.C.String -import Foreign.C.Types import Foreign.ForeignPtr import Foreign.Storable import GHC.Int @@ -61,7 +60,7 @@ instance AFType (Complex Double) where instance AFType (Complex Float) where afType Proxy = c32 -instance AFType CBool where +instance AFType Bool where afType Proxy = b8 instance AFType Int32 where diff --git a/test/ArrayFire/AlgorithmSpec.hs b/test/ArrayFire/AlgorithmSpec.hs index 3243386..43fa1c7 100644 --- a/test/ArrayFire/AlgorithmSpec.hs +++ b/test/ArrayFire/AlgorithmSpec.hs @@ -19,8 +19,8 @@ spec = A.sum (A.scalar @Double 10) 0 `shouldBe` 10.0 A.sum (A.scalar @(A.Complex Double) (1 A.:+ 1)) 0 `shouldBe` 1 A.:+ 1 A.sum (A.scalar @(A.Complex Float) (1 A.:+ 1)) 0 `shouldBe` 1 A.:+ 1 - A.sum (A.scalar @A.CBool 1) 0 `shouldBe` 1 - A.sum (A.scalar @A.CBool 0) 0 `shouldBe` 0 + A.sum (A.scalar @Bool True) 0 `shouldBe` True + A.sum (A.scalar @Bool False) 0 `shouldBe` False it "Should sum a vector" $ do A.sum (A.vector @Int 10 [1..]) 0 `shouldBe` 55 A.sum (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 55 @@ -32,8 +32,8 @@ spec = A.sum (A.vector @Double 10 [1..]) 0 `shouldBe` 55.0 A.sum (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 10.0 A.:+ 10.0 A.sum (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 10.0 A.:+ 10.0 - A.sum (A.vector @A.CBool 10 (repeat 1)) 0 `shouldBe` 10 - A.sum (A.vector @A.CBool 10 (repeat 0)) 0 `shouldBe` 0 +-- A.sum (A.vector @A.CBool 10 (repeat 1)) 0 `shouldBe` 10 +-- A.sum (A.vector @A.CBool 10 (repeat 0)) 0 `shouldBe` 0 it "Should sum a default value to replace NaN" $ do A.sumNaN (A.vector @Float 10 [1..]) 0 1.0 `shouldBe` 55 A.sumNaN (A.vector @Double 2 [acos 2, acos 2]) 0 50 `shouldBe` 100 @@ -50,8 +50,8 @@ spec = A.product (A.scalar @Double 10) 0 `shouldBe` 10.0 A.product (A.scalar @(A.Complex Double) (1 A.:+ 1)) 0 `shouldBe` 1 A.:+ 1 A.product (A.scalar @(A.Complex Float) (1 A.:+ 1)) 0 `shouldBe` 1 A.:+ 1 - A.product (A.scalar @A.CBool 1) 0 `shouldBe` 1 - A.product (A.scalar @A.CBool 0) 0 `shouldBe` 0 +-- A.product (A.scalar @A.CBool 1) 0 `shouldBe` 1 +-- A.product (A.scalar @A.CBool 0) 0 `shouldBe` 0 it "Should product a vector" $ do A.product (A.vector @Int 10 [1..]) 0 `shouldBe` 3628800 A.product (A.vector @A.Int64 10 [1..]) 0 `shouldBe` 3628800 @@ -63,8 +63,8 @@ spec = A.product (A.vector @Double 10 [1..]) 0 `shouldBe` 3628800.0 A.product (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 0.0 A.:+ 32.0 A.product (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 0.0 A.:+ 32.0 - A.product (A.vector @A.CBool 10 (repeat 1)) 0 `shouldBe` 10 -- FIXME: This is a bug, should be 0 - A.product (A.vector @A.CBool 10 (repeat 0)) 0 `shouldBe` 0 +-- A.product (A.vector @A.CBool 10 (repeat 1)) 0 `shouldBe` 10 -- FIXME: This is a bug, should be 0 +-- A.product (A.vector @A.CBool 10 (repeat 0)) 0 `shouldBe` 0 it "Should product a default value to replace NaN" $ do A.productNaN (A.vector @Float 10 [1..]) 0 1.0 `shouldBe` 3628800.0 A.productNaN (A.vector @Double 2 [acos 2, acos 2]) 0 50 `shouldBe` 2500 @@ -81,19 +81,19 @@ spec = A.min (A.vector @Double 10 [1..]) 0 `shouldBe` 1 A.min (A.vector @(A.Complex Double) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 1 A.:+ 1 A.min (A.vector @(A.Complex Float) 10 (repeat (1 A.:+ 1))) 0 `shouldBe` 1 A.:+ 1 - A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 - A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 +-- A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 +-- A.min (A.vector @A.CBool 10 [1..]) 0 `shouldBe` 1 it "Should find if all elements are true" $ do A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` True - A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` True - A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` False - A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` False +-- A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` True +-- A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` False +-- A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` False it "Should find if any elements are true" $ do - A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` True +-- A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` True A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` True - A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` False +-- A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` False it "Should get count of all elements" $ do A.count (A.vector @Int 5 (repeat 1)) 0 `shouldBe` 5 - A.count (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 5 +-- A.count (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 5 A.count (A.vector @Double 5 (repeat 1)) 0 `shouldBe` 5 A.count (A.vector @Float 5 (repeat 1)) 0 `shouldBe` 5 diff --git a/test/ArrayFire/ArithSpec.hs b/test/ArrayFire/ArithSpec.hs index fbbcf4e..6815c47 100644 --- a/test/ArrayFire/ArithSpec.hs +++ b/test/ArrayFire/ArithSpec.hs @@ -11,8 +11,6 @@ spec = describe "Arith tests" $ do it "Should add two scalar arrays" $ do scalar @Int 1 + 2 `shouldBe` 3 - it "Should add two scalar bool arrays" $ do - scalar @CBool 1 + 0 `shouldBe` 1 it "Should subtract two scalar arrays" $ do scalar @Int 4 - 2 `shouldBe` 2 it "Should multiply two scalar arrays" $ do @@ -38,13 +36,13 @@ spec = it "Should eq Array" $ do 3 == (3 :: Array Double) `shouldBe` True it "Should and Array" $ do - ((mkArray @CBool [1] [0] `and` (mkArray [1] [1])) False) - `shouldBe` mkArray [1] [0] + ((mkArray @Bool [1] [False] `and` (mkArray [1] [True])) False) + `shouldBe` mkArray [1] [False] it "Should and Array" $ do - ((mkArray @CBool [2] [0,0] `and` (mkArray [2] [1,0])) False) - `shouldBe` mkArray [2] [0, 0] + ((mkArray @Bool [2] [False,False] `and` (mkArray [2] [True,False])) False) + `shouldBe` mkArray [2] [False, False] it "Should or Array" $ do - ((mkArray @CBool [2] [0,0] `or` (mkArray [2] [1,0])) False) - `shouldBe` mkArray [2] [1, 0] + ((mkArray @Bool [2] [False,False] `or` (mkArray [2] [True,False])) False) + `shouldBe` mkArray [2] [True, False] it "Should not Array" $ do - not (mkArray @CBool [2] [1,1]) `shouldBe` mkArray [2] [0,0] + not (mkArray @Bool [2] [True, True]) `shouldBe` mkArray [2] [False,False] diff --git a/test/ArrayFire/ArraySpec.hs b/test/ArrayFire/ArraySpec.hs index aa7331e..fe485c4 100644 --- a/test/ArrayFire/ArraySpec.hs +++ b/test/ArrayFire/ArraySpec.hs @@ -70,7 +70,7 @@ spec = isSparse arr `shouldBe` False it "Should make a Bit array" $ do - let arr = mkArray @CBool [2,2] [1,1,1,1] + let arr = mkArray @Bool [2,2] (repeat True) isBool arr `shouldBe` True it "Should make an integer array" $ do @@ -80,7 +80,7 @@ spec = it "Should make a Floating array" $ do let arr = mkArray @Double [2,2] (repeat 1) isFloating arr `shouldBe` True - let arr = mkArray @CBool [2,2] (repeat 1) + let arr = mkArray @Bool [2,2] (repeat True) isFloating arr `shouldBe` False it "Should make a Complex array" $ do @@ -122,8 +122,8 @@ spec = let arr = mkArray @Float [10,10] (repeat (5.5)) toList arr `shouldBe` Prelude.replicate 100 5.5 - let arr = mkArray @CBool [4] [1,1,0,1] - toList arr `shouldBe` [1,1,0,1] + let arr = mkArray @Bool [4] [True,True,False,True] + toList arr `shouldBe` [True,True,False,True] let arr = mkArray @Int16 [10] [1..] toList arr `shouldBe` [1..10] diff --git a/test/ArrayFire/UtilSpec.hs b/test/ArrayFire/UtilSpec.hs index b482f48..dcd6039 100644 --- a/test/ArrayFire/UtilSpec.hs +++ b/test/ArrayFire/UtilSpec.hs @@ -22,7 +22,7 @@ spec = A.getSizeOf (Proxy @ Word32) `shouldBe` 4 A.getSizeOf (Proxy @ Word16) `shouldBe` 2 A.getSizeOf (Proxy @ Word8) `shouldBe` 1 - A.getSizeOf (Proxy @ CBool) `shouldBe` 1 + A.getSizeOf (Proxy @ Bool) `shouldBe` 1 A.getSizeOf (Proxy @ Double) `shouldBe` 8 A.getSizeOf (Proxy @ Float) `shouldBe` 4 A.getSizeOf (Proxy @ (Complex Float)) `shouldBe` 8