{-|
Module      : MachineLearning.Model.Measure 
Description : Performance measures for regression and classification
Copyright   : (c) Fabricio Olivetti de Franca, 2022
License     : GPL-3
Maintainer  : fabricio.olivetti@gmail.com
Stability   : experimental
Portability : POSIX

Performance measures for Regression and Classification.
-}
module MachineLearning.Model.Measure 
  ( Measure(..)
  , toMeasure
  , measureAll
  , _rmse
  )
  where

import Data.Semigroup (Sum(..))

import qualified Numeric.LinearAlgebra       as LA
import qualified Numeric.Morpheus.Statistics as Stat
import qualified Data.Vector.Storable        as V

-- | Vector of 'Double' values shared by every measure in this module.
type Vector = LA.Vector Double
-- * Performance measures

-- | A named performance measure.
--
-- The wrapped function takes the vector of true values first, the vector of
-- predicted values second, and returns a single 'Double'.
data Measure = Measure
  { _name :: String                      -- ^ name used for display and lookup
  , _fun  :: Vector -> Vector -> Double  -- ^ true values -> predicted values -> measure
  }

-- | A measure is shown as its name only; the function field is not printed.
instance Show Measure where
  show = _name

-- | Arithmetic mean of a vector of doubles.
--
-- An empty vector evaluates @0/0@, i.e. NaN.
mean :: Vector -> Double
mean xs = total / count
  where
    total, count :: Double
    total = V.sum xs
    count = fromIntegral (V.length xs)
{-# INLINE mean #-}

-- | Population variance of a vector of doubles: the sum of squared
-- deviations from the 'mean', divided by @n@ (not @n - 1@).
--
-- An empty vector evaluates @0/0@, i.e. NaN.
var :: Vector -> Double
var xs = sumSq / fromIntegral (V.length xs)
  where
    mu :: Double
    mu = mean xs
    -- Strict left fold: the lazy 'V.foldl' may build a chain of thunks on
    -- long vectors. Summation order (and therefore the result) is unchanged.
    sumSq :: Double
    sumSq = V.foldl' (\acc x -> acc + (x - mu) ^ (2 :: Int)) 0 xs
{-# INLINE var #-}

-- | Generic mean-error measure: transforms the residuals @ysHat - ys@ with
-- the given function and averages the result.
meanError :: (Vector -> Vector) -- ^ transformation applied to the residuals (abs, square, ...)
          -> Vector             -- ^ target values
          -> Vector             -- ^ fitted values
          -> Double
meanError op ys ysHat = mean (op residuals)
  where
    residuals = ysHat - ys
{-# INLINE meanError #-}

-- * Common error measures for regression:
-- MSE, MAE, RMSE, NMSE, r^2

-- | Mean Squared Error between the target values @ys@ and the predictions
-- @ysHat@. Equivalent to @meanError (^ (2 :: Int))@, written out directly.
mse :: Vector -> Vector -> Double
mse ys ysHat = mean squared
  where
    residuals = ysHat - ys
    squared   = residuals ^ (2 :: Int)
{-# INLINE mse #-}

-- | Mean Absolute Error between the target values @ys@ and the predictions
-- @ysHat@. Equivalent to @meanError abs@, written out directly.
mae :: Vector -> Vector -> Double
mae ys ysHat = mean absResiduals
  where
    absResiduals = abs (ysHat - ys)
{-# INLINE mae #-}

-- | Normalized Mean Squared Error: 'mse' divided by the population variance
-- ('var') of the target values. A constant target has zero variance, so the
-- result is a division by zero.
nmse :: Vector -> Vector -> Double
nmse ys ysHat = err / spread
  where
    err    = mse ys ysHat
    spread = var ys
{-# INLINE nmse #-}

-- | Root Mean Squared Error: the square root of 'mse'.
rmse :: Vector -> Vector -> Double
rmse ys ysHat = sqrt (mse ys ysHat)
{-# INLINE rmse #-}

-- | Negated coefficient of determination, @-(1 - SS_res / SS_tot)@, so that
-- it can be used as a minimization metric.
--
-- @SS_tot@ is the sum of squared deviations of @ys@ from its mean, and
-- @SS_res@ is the sum of squared residuals @ys - ysHat@. A constant target
-- makes @SS_tot@ zero, so the result is a division by zero.
rSq :: Vector -> Vector -> Double
rSq ys ysHat = negate (1 - r/t)
  where
    ym :: Double
    ym      = Stat.mean ys
    t :: Double
    t       = sumOfSq $ V.map (\yi -> yi - ym) ys
    r :: Double
    r       = sumOfSq $ ys - ysHat
    -- Strict left fold: the lazy 'V.foldl' may build a chain of thunks on
    -- long vectors. Summation order (and therefore the result) is unchanged.
    sumOfSq :: Vector -> Double
    sumOfSq = V.foldl' (\s di -> s + di^(2 :: Int)) 0
{-# INLINE rSq #-}

-- * Regression measures
-- | Root Mean Squared Error measure (see 'rmse').
_rmse :: Measure
_rmse = Measure { _name = "RMSE", _fun = rmse }

-- | Mean Absolute Error measure (see 'mae').
_mae :: Measure
_mae = Measure { _name = "MAE", _fun = mae }

-- | Normalized Mean Squared Error measure (see 'nmse').
_nmse :: Measure
_nmse = Measure { _name = "NMSE", _fun = nmse }

-- | Coefficient of determination measure. Note that 'rSq' returns the
-- /negated/ R^2.
_r2 :: Measure
_r2 = Measure { _name = "R^2", _fun = rSq }

-- * Classification measures
-- | Accuracy measure (see 'accuracy'; the value is negated).
_accuracy :: Measure
_accuracy = Measure { _name = "Accuracy", _fun = accuracy }

-- | Recall measure (see 'recall').
_recall :: Measure
_recall = Measure { _name = "Recall", _fun = recall }

-- | Precision measure (see 'precision').
_precision :: Measure
_precision = Measure { _name = "Precision", _fun = precision }

-- | F1 measure (see 'f1').
_f1 :: Measure
_f1 = Measure { _name = "F1", _fun = f1 }

-- | Log-loss measure (see 'logloss').
_logloss :: Measure
_logloss = Measure { _name = "Log-Loss", _fun = logloss }

-- | Negated accuracy: minus the fraction of positions where the rounded
-- prediction equals the rounded true value, so a perfect match scores @-1@.
--
-- Values are converted to labels with 'round'. Vectors of different lengths
-- are compared only up to the shorter one ('zip'). An empty input evaluates
-- @0/0@, i.e. NaN.
--
-- NOTE(review): the negation presumably makes this a minimization metric,
-- like 'rSq'; confirm with callers.
accuracy :: Vector -> Vector -> Double
accuracy ys ysHat = negate (hits / total)
  where
    labels :: Vector -> [Integer]
    labels = map round . LA.toList

    (Sum hits, Sum total) = foldMap score $ zip (labels ysHat) (labels ys)

    -- (predicted, actual): every pair counts toward the total,
    -- and only matching pairs count as hits.
    score :: (Integer, Integer) -> (Sum Double, Sum Double)
    score (predicted, actual)
      | predicted == actual = (Sum 1, Sum 1)
      | otherwise           = (Sum 0, Sum 1)

-- | Precision, @TP / (TP + FP)@: among the positions whose rounded
-- prediction is @1@ and whose rounded true value is @0@ or @1@, the fraction
-- whose true value is @1@.
--
-- Values are converted to labels with 'round'. All other label pairs are
-- ignored. If no pair qualifies, the result is @0/0@, i.e. NaN.
precision :: Vector -> Vector -> Double
precision ys ysHat = truePos / predictedPos
  where
    labels :: Vector -> [Integer]
    labels = map round . LA.toList

    (Sum truePos, Sum predictedPos) = foldMap score $ zip (labels ysHat) (labels ys)

    -- (predicted, actual)
    score :: (Integer, Integer) -> (Sum Double, Sum Double)
    score (1, 1) = (Sum 1, Sum 1)   -- true positive
    score (1, 0) = (Sum 0, Sum 1)   -- false positive
    score _      = (Sum 0, Sum 0)   -- not a predicted positive

-- | Recall, @TP / (TP + FN)@: among the positions whose rounded true value
-- is @1@ and whose rounded prediction is @0@ or @1@, the fraction predicted
-- as @1@.
--
-- Values are converted to labels with 'round'. All other label pairs are
-- ignored. If no pair qualifies, the result is @0/0@, i.e. NaN.
recall :: Vector -> Vector -> Double
recall ys ysHat = truePos / actualPos
  where
    labels :: Vector -> [Integer]
    labels = map round . LA.toList

    (Sum truePos, Sum actualPos) = foldMap score $ zip (labels ysHat) (labels ys)

    -- (predicted, actual)
    score :: (Integer, Integer) -> (Sum Double, Sum Double)
    score (1, 1) = (Sum 1, Sum 1)   -- true positive
    score (0, 1) = (Sum 0, Sum 1)   -- false negative
    score _      = (Sum 0, Sum 0)   -- not an actual positive

-- | F1 score: the harmonic mean of 'precision' and 'recall',
-- @2 * P * R / (P + R)@.
--
-- When both precision and recall are zero (there are no true positives),
-- the score is defined as @0@ rather than evaluating @0/0@ (NaN). A NaN
-- precision or recall (e.g. no qualifying label pairs) still propagates,
-- because any comparison with NaN is 'False'.
f1 :: Vector -> Vector -> Double
f1 ys ysHat
  | prec + rec == 0 = 0
  | otherwise       = 2*prec*rec/(prec+rec)
  where
    -- Arguments in the documented (true, predicted) order. The previous
    -- swapped calls computed recall as "prec" and precision as "rec"; F1 is
    -- symmetric in P and R, so the value is unchanged.
    prec, rec :: Double
    prec = precision ys ysHat
    rec  = recall ys ysHat

-- | Log-loss (binary cross-entropy) of a classifier that returns a
-- probability: @mean (-(y * log p + (1 - y) * log (1 - p)))@.
--
-- Predictions are clipped to @[1e-15, 1 - 1e-15]@ before taking
-- logarithms, so neither @log 0@ nor @log 1@ of an exact 0/1 prediction
-- produces an infinite term.
logloss :: Vector -> Vector -> Double
logloss ys ysHat = mean (negate (posTerm + negTerm))
  where
    eps :: Double
    eps = 1e-15
    -- clip probabilities away from 0 and 1
    p :: Vector
    p = LA.cmap (min (1.0 - eps) . max eps) ysHat
    posTerm, negTerm :: Vector
    posTerm = ys * log p
    negTerm = (1 - ys) * log (1 - p)



-- | Every measure defined in this module: the regression measures followed
-- by the classification measures. 'toMeasure' searches this list by name.
measureAll :: [Measure]
measureAll = regression ++ classification
  where
    regression     = [_rmse, _mae, _nmse, _r2]
    classification = [_accuracy, _recall, _precision, _f1, _logloss]

-- | Look up a measure by its exact name, as listed in 'measureAll'
-- (e.g. @"RMSE"@, @"R^2"@, @"F1"@).
--
-- Calls 'error' with @"Invalid measure: <input>"@ when no measure has that
-- name. If several measures shared a name, the first one in 'measureAll'
-- would be returned.
toMeasure :: String -> Measure
toMeasure input =
  -- Total pattern match instead of a 'null' check followed by a partial 'head'.
  case filter ((== input) . _name) measureAll of
    (m:_) -> m
    []    -> error ("Invalid measure: " ++ input)