{-# LANGUAGE
MultiParamTypeClasses, FlexibleInstances, FlexibleContexts,
UndecidableInstances, ForeignFunctionInterface, BangPatterns,
RankNTypes
#-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
module Data.Random.Distribution.Normal
( Normal(..)
, normal, normalT
, stdNormal, stdNormalT
, doubleStdNormal
, floatStdNormal
, realFloatStdNormal
, normalTail
, normalPair
, boxMullerNormalPair
, knuthPolarNormalPair
) where
import Data.Random.Internal.Words
import Data.Bits
import Data.Random.Source
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.Distribution.Ziggurat
import Data.Random.RVar
import Data.Vector.Generic (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as UV
import Data.Number.Erf
-- | A random variable that produces a pair of independent standard-normal
-- values. Currently an alias for 'boxMullerNormalPair'.
normalPair :: (Floating a, Distribution StdUniform a) => RVar (a,a)
normalPair = boxMullerNormalPair
{-# INLINE boxMullerNormalPair #-}
-- | A random variable that produces a pair of independent standard-normal
-- values using the Box-Muller transform.
--
-- It always consumes exactly two 'stdUniform' draws per pair, with no
-- rejection loop. Unlike 'knuthPolarNormalPair' it needs no 'Ord' instance.
--
-- NOTE(review): if the 'StdUniform' instance in use can yield exactly 0,
-- then @log u@ is -Infinity and both components become non-finite.
-- Guarding against that here would need an 'Eq'/'Ord' constraint (an API
-- change), so it is left as-is. Confirm the interval of the instances in use.
boxMullerNormalPair :: (Floating a, Distribution StdUniform a) => RVar (a,a)
boxMullerNormalPair = do
    u <- stdUniform
    t <- stdUniform
    let radius = sqrt (-2 * log u)   -- Rayleigh-distributed radius
        angle  = 2 * pi * t          -- uniform angle in [0, 2*pi)
    return (radius * cos angle, radius * sin angle)
{-# INLINE knuthPolarNormalPair #-}
-- | A random variable that produces a pair of independent standard-normal
-- values using the (Marsaglia/Knuth) polar method.
--
-- It samples a point @(v1, v2)@ uniformly from the square @[-1,1]^2@ and
-- retries until the point falls strictly inside the unit disc and is not
-- the origin. The accepted point is then scaled by @sqrt (-2 * log s / s)@,
-- where @s = v1^2 + v2^2@.
--
-- Because this is a rejection method, the number of uniform draws per pair
-- is not fixed.
knuthPolarNormalPair :: (Floating a, Ord a, Distribution Uniform a) => RVar (a,a)
knuthPolarNormalPair = do
    v1 <- uniform (-1) 1
    v2 <- uniform (-1) 1
    let s = v1*v1 + v2*v2
    -- At s == 0 the scale factor is undefined (log 0 / 0). The method rejects
    -- that point exactly like points outside the disc. Returning a fixed
    -- (0,0) there, as the previous code did, adds a spurious point mass.
    if s >= 1 || s == 0
        then knuthPolarNormalPair
        else let scale = sqrt (-2 * log s / s)
              in return (v1 * scale, v2 * scale)
{-# INLINE normalTail #-}
-- | Draw from the tail of a standard normal distribution: the region at or
-- beyond the cut-off @r@. The cut-off is assumed positive; TODO confirm at
-- call sites.
--
-- Rejection sampler:
--
-- * propose @r - x@ with @x = log u / r@ (so @x <= 0@);
-- * accept when @x*x + 2*y <= 0@, where @y = log v@;
-- * otherwise retry.
--
-- Both @u@ and @v@ come from 'stdUniformT'.
normalTail :: (Distribution StdUniform a, Floating a, Ord a) =>
              a -> RVarT m a
normalTail r = go
    where
        go = do
            !u <- stdUniformT
            let !x = log u / r
            !v <- stdUniformT
            let !y = log v
            -- The test is phrased as "accept when <= 0" rather than "reject
            -- when > 0" so that a NaN here is rejected. A NaN arises when both
            -- draws are 0: Infinity + (-Infinity). The old phrasing accepted
            -- it and returned r - (-Infinity) = Infinity. For every non-NaN
            -- value the two phrasings agree.
            if x*x + y + y <= 0
                then return (r - x)
                else go
-- | Build a 'Ziggurat' for the standard normal distribution from 'normalF',
-- its inverse 'normalFInv', its integral 'normalFInt' and its total volume
-- 'normalFVol'.
--
-- @2^p@ is passed to 'mkZigguratRec' as its 'Int' size argument, presumably
-- the number of layers; confirm against 'mkZigguratRec'. @getIU@ is the
-- sampler that 'mkZigguratRec' uses to draw a @(layer index, uniform)@ pair.
normalZ ::
  (RealFloat a, Erf a, Vector v a, Distribution Uniform a, Integral b) =>
  b -> (forall m. RVarT m (Int, a)) -> Ziggurat v a
-- getIU is bound explicitly (not eta-reduced) so its rank-2 type is checked
-- directly against mkZigguratRec's parameter. This also holds under GHC's
-- simplified subsumption.
normalZ p getIU =
    mkZigguratRec True normalF normalFInv normalFInt normalFVol (2 ^ p) getIU
-- | Ziggurat target function: the unnormalised Gaussian @exp (-x^2 / 2)@
-- for @x > 0@. It is clamped to 1 for @x <= 0@.
normalF :: (Floating a, Ord a) => a -> a
normalF x
    | x <= 0    = 1
    | otherwise = exp ((-0.5) * x * x)
-- | Inverse of 'normalF' on its positive branch. For @y@ in @(0, 1]@ it
-- returns the @x >= 0@ with @exp (-x^2 / 2) == y@.
normalFInv :: Floating a => a -> a
normalFInv y = sqrt ((-2) * log y)
-- | Integral of 'normalF' from 0 to @x@. It is 0 for @x <= 0@; otherwise
-- @sqrt (pi/2) * erf (x / sqrt 2)@.
normalFInt :: (Floating a, Erf a, Ord a) => a -> a
normalFInt x
    | x <= 0    = 0
    | otherwise = normalFVol * erf (x * sqrt 0.5)
-- | Total area under 'normalF' on @[0, infinity)@, i.e. @sqrt (pi / 2)@.
normalFVol :: Floating a => a
normalFVol = sqrt (0.5 * pi)
-- | A random variable sampling the standard normal distribution over any
-- 'RealFloat' type that also has 'Erf' and 'Uniform' instances.
--
-- It runs a ziggurat built by 'normalZ' with @2^6@ layers, stored in a boxed
-- 'V.Vector'.
--
-- NOTE(review): the ziggurat is constructed inside this definition. Unless
-- GHC floats it out for a monomorphic use, it may be rebuilt per use. Prefer
-- 'doubleStdNormal' / 'floatStdNormal' for 'Double' / 'Float'.
realFloatStdNormal :: (RealFloat a, Erf a, Distribution Uniform a) => RVarT m a
realFloatStdNormal = runZiggurat (boxed (normalZ p getIU))
    where
        -- Pins the ziggurat's table type to boxed vectors without resorting to
        -- a partial `undefined` placeholder.
        boxed :: Ziggurat V.Vector t -> Ziggurat V.Vector t
        boxed = id

        -- log2 of the layer count; must match the mask used in getIU.
        p :: Int
        p = 6

        -- Layer index from the low p bits of a random byte, paired with a
        -- uniform value in (-1, 1).
        getIU :: (Num t, Distribution Uniform t) => RVarT n (Int, t)
        getIU = do
            i <- getRandomWord8
            u <- uniformT (-1) 1
            return (fromIntegral i .&. (2 ^ p - 1), u)
-- | A random variable sampling the standard normal distribution over
-- 'Double', using the shared precomputed ziggurat 'doubleStdNormalZ'.
doubleStdNormal :: RVarT m Double
doubleStdNormal = runZiggurat doubleStdNormalZ
-- | Size argument for the 'Double' ziggurat, presumably its layer count.
-- 'doubleStdNormalZ' also uses @doubleStdNormalC - 1@ as a bit mask, so it
-- must be a power of two.
--
-- NOTE(review): the mask is applied to the excess bits returned by
-- 'wordToDoubleWithExcess'. If that function yields fewer than
-- @log2 doubleStdNormalC@ bits, the high layers are never chosen. Confirm.
doubleStdNormalC :: Int
doubleStdNormalC = 512

-- | 'doubleStdNormalR' is the ziggurat cut-off; it is also the start of the
-- tail sampled by 'normalTail'. 'doubleStdNormalV' is passed to
-- 'mkZiggurat_' alongside it, presumably the per-layer volume (confirm).
-- Both are precomputed for 'doubleStdNormalC' layers.
doubleStdNormalR, doubleStdNormalV :: Double
doubleStdNormalR = 3.852046150368388
doubleStdNormalV = 2.4567663515413507e-3
-- NOINLINE keeps this a single shared top-level table, so it is built at
-- most once rather than duplicated at use sites.
{-# NOINLINE doubleStdNormalZ #-}
-- | Precomputed ziggurat for the standard normal over 'Double'. Draws in the
-- tail beyond 'doubleStdNormalR' are delegated to 'normalTail'.
doubleStdNormalZ :: Ziggurat UV.Vector Double
doubleStdNormalZ =
    mkZiggurat_ True
        normalF normalFInv
        doubleStdNormalC doubleStdNormalR doubleStdNormalV
        getIU
        (normalTail doubleStdNormalR)
    where
        -- One 64-bit word supplies both the layer index and the uniform
        -- value. The excess bits pick the layer; the Double part u is mapped
        -- to 2u - 1. Presumably u lies in [0,1), giving [-1,1); confirm
        -- against wordToDoubleWithExcess.
        getIU :: RVarT m (Int, Double)
        getIU = do
            !w <- getRandomWord64
            let (u, i) = wordToDoubleWithExcess w
            return $! (fromIntegral i .&. (doubleStdNormalC - 1), u + u - 1)
-- | A random variable sampling the standard normal distribution over
-- 'Float', using the shared precomputed ziggurat 'floatStdNormalZ'.
floatStdNormal :: RVarT m Float
floatStdNormal = runZiggurat floatStdNormalZ
-- | Size argument for the 'Float' ziggurat, presumably its layer count.
-- 'floatStdNormalZ' also uses @floatStdNormalC - 1@ as a bit mask, so it
-- must be a power of two.
--
-- NOTE(review): the mask is applied to the excess bits returned by
-- 'word32ToFloatWithExcess'. If that function yields fewer than
-- @log2 floatStdNormalC@ bits, the high layers are never chosen. Confirm.
floatStdNormalC :: Int
floatStdNormalC = 512

-- | 'floatStdNormalR' is the ziggurat cut-off; it is also the start of the
-- tail sampled by 'normalTail'. 'floatStdNormalV' is passed to 'mkZiggurat_'
-- alongside it, presumably the per-layer volume (confirm). Both share the
-- 'Double' ziggurat's literals, rounded to 'Float'.
floatStdNormalR, floatStdNormalV :: Float
floatStdNormalR = 3.852046150368388
floatStdNormalV = 2.4567663515413507e-3
-- NOINLINE keeps this a single shared top-level table, so it is built at
-- most once rather than duplicated at use sites.
{-# NOINLINE floatStdNormalZ #-}
-- | Precomputed ziggurat for the standard normal over 'Float'. Draws in the
-- tail beyond 'floatStdNormalR' are delegated to 'normalTail'.
floatStdNormalZ :: Ziggurat UV.Vector Float
floatStdNormalZ =
    mkZiggurat_ True
        normalF normalFInv
        floatStdNormalC floatStdNormalR floatStdNormalV
        getIU
        (normalTail floatStdNormalR)
    where
        -- One 32-bit word supplies both the layer index and the uniform
        -- value. The excess bits pick the layer; the Float part u is mapped
        -- to 2u - 1. Presumably u lies in [0,1), giving [-1,1); confirm
        -- against word32ToFloatWithExcess.
        getIU :: RVarT m (Int, Float)
        getIU = do
            !w <- getRandomWord32
            let (u, i) = word32ToFloatWithExcess w
            return (fromIntegral i .&. (floatStdNormalC - 1), u + u - 1)
-- | Cumulative distribution function of a normal distribution with mean @m@
-- and standard deviation @s@, evaluated at @x@.
--
-- The point is standardised to @(x - m) / s@ in 'Double' and passed to
-- 'normcdf'.
normalCdf :: (Real a) => a -> a -> a -> Double
normalCdf m s x = normcdf (z :: Double)
  where
    z = (realToFrac x - realToFrac m) / realToFrac s
-- | Probability density of a normal distribution with mean @mu@ and standard
-- deviation @sigma@, evaluated at @x@:
--
-- > exp (-(x - mu)^2 / (2 * sigma^2)) / sqrt (2 * pi * sigma^2)
--
-- Parameters are converted with 'realToFrac', so any 'Real' input type can
-- be evaluated in any 'Floating' result type.
normalPdf :: (Real a, Floating b) => a -> a -> a -> b
normalPdf mu sigma x =
    recip (sqrt (2 * pi * sigma2)) * exp (negate (d ^ 2) / (2 * sigma2))
  where
    -- variance, in the result type
    sigma2 = realToFrac sigma ^ 2
    -- distance from the mean, in the result type
    d = realToFrac x - realToFrac mu
-- | Natural log of the density of a normal distribution with mean @mu@ and
-- standard deviation @sigma@, evaluated at @x@:
--
-- > -log (2 * pi * sigma^2) / 2 - (x - mu)^2 / (2 * sigma^2)
--
-- This is computed directly in log space, with no 'exp', so it stays finite
-- far into the tails where the density itself would underflow to 0.
normalLogPdf :: (Real a, Floating b) => a -> a -> a -> b
normalLogPdf mu sigma x =
    negate (log (2 * pi * sigma2)) / 2 - d ^ 2 / (2 * sigma2)
  where
    -- variance, in the result type
    sigma2 = realToFrac sigma ^ 2
    -- distance from the mean, in the result type
    d = realToFrac x - realToFrac mu
-- | A specification of a normal (Gaussian) distribution over the type @a@.
data Normal a
    -- | The standard normal distribution: mean 0, standard deviation 1.
    = StdNormal
    -- | @Normal m s@: mean @m@, standard deviation @s@.
    | Normal a a
-- | 'Double' sampling uses the precomputed ziggurat behind 'doubleStdNormal'.
-- A general 'Normal' is obtained by scaling a standard draw by @s@ and
-- shifting it by @m@.
instance Distribution Normal Double where
    rvarT StdNormal    = doubleStdNormal
    rvarT (Normal m s) = (\z -> z * s + m) <$> doubleStdNormal
-- | 'Float' sampling uses the precomputed ziggurat behind 'floatStdNormal'.
-- A general 'Normal' is obtained by scaling a standard draw by @s@ and
-- shifting it by @m@.
instance Distribution Normal Float where
    rvarT StdNormal    = floatStdNormal
    rvarT (Normal m s) = (\z -> z * s + m) <$> floatStdNormal
-- | The CDF of a normal distribution; 'StdNormal' is mean 0, std. dev. 1.
instance (Real a, Distribution Normal a) => CDF Normal a where
    cdf StdNormal    = normalCdf 0 1
    cdf (Normal m s) = normalCdf m s
-- | Density and log-density of a normal distribution, evaluated in 'Double';
-- 'StdNormal' is mean 0, std. dev. 1.
instance (Real a, Floating a, Distribution Normal a) => PDF Normal a where
    pdf StdNormal       = normalPdf 0 1
    pdf (Normal m s)    = normalPdf m s
    logPdf StdNormal    = normalLogPdf 0 1
    logPdf (Normal m s) = normalLogPdf m s
{-# SPECIALIZE stdNormal :: RVar Double #-}
{-# SPECIALIZE stdNormal :: RVar Float #-}
-- | A random variable with the 'StdNormal' distribution (mean 0, std. dev. 1).
stdNormal :: Distribution Normal a => RVar a
stdNormal = rvar StdNormal
-- | 'stdNormal' for an arbitrary underlying monad: a 'StdNormal' draw as an
-- 'RVarT'.
stdNormalT :: Distribution Normal a => RVarT m a
stdNormalT = rvarT StdNormal
-- | @normal m s@ is a random variable with mean @m@ and standard deviation @s@.
normal :: Distribution Normal a => a -> a -> RVar a
normal m s = rvar (Normal m s)
-- | 'normal' for an arbitrary underlying monad: @normalT m s@ draws from a
-- normal distribution with mean @m@ and standard deviation @s@.
normalT :: Distribution Normal a => a -> a -> RVarT m a
normalT m s = rvarT (Normal m s)