Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector indexing and insertion operations #509

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ data PrimFun sig where

-- local array operators
HugoPeters1024 marked this conversation as resolved.
Show resolved Hide resolved
PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a)
PrimVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, (i, a)) -> Vec n a)

-- general conversion between types
PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b)
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Analysis/Hash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq")
encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a
encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a
encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b)
encodePrimFun (PrimVectorWrite (VectorType _ a) b) = intHost $(hashQ "PrimVectorWrite") <> encodeSingleType a <> encodeNumType (IntegralNumType b)
encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b
encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b
encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd")
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Classes/Vector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Data.Primitive.Vec
instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where
type IndexType (Exp (Vec n a)) = Exp Int
vecIndex = mkVectorIndex
vecWrite = mkVectorWrite
vecEmpty = undef


4 changes: 4 additions & 0 deletions src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,7 @@ evalPrim PrimLAnd = evalLAnd
evalPrim PrimLOr = evalLOr
evalPrim PrimLNot = evalLNot
evalPrim (PrimVectorIndex v i) = evalVectorIndex v i
evalPrim (PrimVectorWrite v i) = evalVectorWrite v i
evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb
evalPrim (PrimToFloating ta tb) = evalToFloating ta tb

Expand Down Expand Up @@ -1174,6 +1175,9 @@ evalLNot = fromBool . not . toBool
evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a
evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i)

evalVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, (i, a)) -> Vec n a
evalVectorWrite (VectorType n _) ti (v, (i, a)) | IntegralDict <- integralDict ti = vecWrite v (fromIntegral i) a

evalFromIntegral :: IntegralType a -> NumType b -> a -> b
evalFromIntegral ta (IntegralNumType tb)
| IntegralDict <- integralDict ta
Expand Down
9 changes: 9 additions & 0 deletions src/Data/Array/Accelerate/Smart.hs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ module Data.Array.Accelerate.Smart (
-- ** Smart constructors for vector operations
mkVectorCreate,
mkVectorIndex,
mkVectorWrite,

-- ** Auxiliary functions
($$), ($$$), ($$$$), ($$$$$),
Expand Down Expand Up @@ -1190,6 +1191,11 @@ mkVectorIndex = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType

mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a)
mkVectorWrite = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimTernary $ PrimVectorWrite @n (VectorType n singleType) integralType
HugoPeters1024 marked this conversation as resolved.
Show resolved Hide resolved

-- Numeric conversions

mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b
Expand Down Expand Up @@ -1277,6 +1283,9 @@ mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a
mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c
mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b)

mkPrimTernary :: (Elt a, Elt b, Elt c, Elt d) => PrimFun ((EltR a, (EltR b, EltR c)) -> EltR d) -> Exp a -> Exp b -> Exp c -> Exp d
mkPrimTernary prim (Exp a) (Exp b) (Exp c) = mkExp $ PrimApp prim (SmartExp $ Pair a (SmartExp (Pair b c)))

mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool
mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary

Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Trafo/Algebra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ evalPrimApp env f x
PrimMax ty -> evalMax ty x env
PrimMin ty -> evalMin ty x env
PrimVectorIndex _ _ -> Nothing
PrimVectorWrite _ _ -> Nothing
PrimLAnd -> evalLAnd x env
PrimLOr -> evalLOr x env
PrimLNot -> evalLNot x env
Expand Down
10 changes: 10 additions & 0 deletions src/Data/Primitive/Vec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
HugoPeters1024 marked this conversation as resolved.
Show resolved Hide resolved
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Primitive.Vec
Expand Down Expand Up @@ -96,6 +97,7 @@ data Vec (n :: Nat) a = Vec ByteArray#
class Vectoring vector a | vector -> a where
type IndexType vector :: Data.Kind.Type
HugoPeters1024 marked this conversation as resolved.
Show resolved Hide resolved
vecIndex :: vector -> IndexType vector -> a
vecWrite :: vector -> IndexType vector -> a -> vector
vecEmpty :: vector

instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where
Expand All @@ -104,6 +106,14 @@ instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where
n :: Int
n = fromIntegral $ natVal $ Proxy @n
in if i >= 0 && i < n then indexByteArray# ba# iu# else error ("index " <> show i <> " out of range in Vec of size " <> show n)
vecWrite vec@(Vec ba#) i@(I# iu#) v = runST $ do
let n :: Int
n = fromIntegral $ natVal $ Proxy @n
mba <- newByteArray (n * sizeOf (undefined :: a))
let new_vs = zipWith (\i' v' -> if i' == i then v else v') [0..n] (listOfVec vec)
zipWithM_ (writeByteArray mba) [0..n] new_vs
ByteArray nba# <- unsafeFreezeByteArray mba
return $! Vec nba#
vecEmpty = mkVec


Expand Down