{-# language FlexibleInstances #-}
{-|
Module      : MachineLearning.Model.Fitness 
Description : Fitness evaluation of TIR expressions
Copyright   : (c) Fabricio Olivetti de Franca, 2022
License     : GPL-3
Maintainer  : fabricio.olivetti@gmail.com
Stability   : experimental
Portability : POSIX

This module exports the functions that calculate the coefficients and
evaluate the fitness, penalty and constraints of a TIR expression.
-}
module MachineLearning.Model.Fitness 
  ( evalTrain
  , evalTest
  , selectValidTerms
  ) where

import Data.Bifunctor
import Data.Maybe            (fromJust)
import Data.Vector.Storable  (Vector)
import Data.Vector           ((!))
import Numeric.LinearAlgebra ((<\>))
import Numeric.ModalInterval      (Kaucher, inf, sup, width, singleton)
import qualified Numeric.ModalInterval as Interval 
import Control.Monad.Reader
import qualified Data.Vector.Storable  as VS
import qualified Data.Vector           as V
import qualified Numeric.LinearAlgebra as LA
import Data.Maybe (fromMaybe)
import MachineLearning.Model.Measure       (Measure)
import MachineLearning.Model.Regression    (nonlinearFit, evalPenalty, fitTask, predictTask, applyMeasures)
import MachineLearning.TIR       (TIR(..),  Individual(..), Dataset, Constraint, assembleTree, replaceConsts)
import MachineLearning.Utils.Config       (Task(..), Penalty)
import Data.SRTree                (SRTree(..), Function, OptIntPow(..), evalTree, evalTreeMap, evalFun, inverseFunc, countNodes)

import Data.SRTree.Print

-- | Removes invalid terms from the TIR expression. A term is invalid
-- when its transformation function, evaluated with interval arithmetic
-- over the variable domains, yields an interval rejected by
-- 'isValidInterval' (e.g. one with @NaN@ or infinite bounds).
-- The domains are either provided by the configuration file or
-- estimated using the training data.
--
-- The transformation function of the whole expression is left untouched;
-- only the numerator ('_p') and denominator ('_q') term lists are filtered.
--
-- NOTE: domains are looked up with the partial '!', so every variable
-- index used by a term is assumed to be a valid index of @domains@.
selectValidTerms :: TIR -> V.Vector (Kaucher Double) -> TIR
selectValidTerms tir@(TIR _ p q) domains = tir{ _p = keepValid p, _q = keepValid q }
  where
    -- keep only the terms whose image over the domains is a valid interval;
    -- the coefficient of the term plays no role in the decision
    keepValid = filter (\(_, g, ps) -> isValidInterval (evalFun g (evalPi ps)))

    -- interval image of an interaction: the product of each variable's
    -- domain raised to its integer exponent
    evalPi = foldr (\(ix, k) acc -> acc * ((domains ! ix) ^. k)) 1
{-# INLINE selectValidTerms #-}

-- | An interval is valid exactly when 'isInvalidInterval' does not reject it.
isValidInterval :: Kaucher Double -> Bool
isValidInterval ys = not (isInvalidInterval ys)
{-# INLINE isValidInterval #-}

-- | Checks whether an interval must be rejected. The checks are evaluated
-- in order and short-circuit on the first one that holds:
--
-- * the interval is empty or flagged invalid by the interval library;
-- * either bound is infinite (a missing bound counts as infinite);
-- * the upper bound is below the lower bound;
-- * either bound has magnitude of at least @1e50@;
-- * either bound is @NaN@;
-- * the width of the interval is below @1e-8@.
isInvalidInterval :: Kaucher Double -> Bool
isInvalidInterval ys =  Interval.isEmpty ys
                     || Interval.isInvalid ys
                     || isInfinite lower || isInfinite upper
                     || upper < lower
                     || abs lower >= 1e50 || abs upper >= 1e50
                     || isNaN lower || isNaN upper
                     || width ys < 1e-8
  where
    -- a missing bound is replaced by the corresponding infinity, which
    -- makes the interval fail the 'isInfinite' check above
    lower = fromMaybe (-1/0) (inf ys)
    upper = fromMaybe (1/0) (sup ys)
{-# INLINE isInvalidInterval #-}

-- | Evaluates an individual: removes the invalid terms of its expression,
-- fits the coefficients with 'fitTask' on the training data, and then
-- calculates the fitness vector (on the validation data), the constraint
-- value, the expression length and the penalty.
--
-- When @isRefit@ is 'False' and the individual already carries a non-empty
-- fitness, it is returned unchanged (it was evaluated before).
--
-- NOTE: the first weight vector returned by 'fitTask' is taken with the
-- partial 'head' (and the fitness with 'fromJust'), so 'fitTask' is assumed
-- to return at least one non-empty weight vector. The first element of that
-- vector is used as the bias and the remaining ones as term coefficients.
evalTrain :: Task                          -- ^ Regression or Classification task
          -> Bool                          -- ^ if we are fitting the final best individual, in this case do not split the training data for validation
          -> [Measure]                     -- ^ list of performance measures to calculate
          -> Constraint                    -- ^ constraint function
          -> Penalty                       -- ^ penalty
          -> V.Vector (Kaucher Double)     -- ^ variable domains represented as a Kaucher Interval
          -> Dataset Double                -- ^ training data
          -> Vector Double                 -- ^ training target
          -> Dataset Double                -- ^ validation data
          -> Vector Double                 -- ^ validation target
          -> Individual
          -> Individual
evalTrain task isRefit measures cnstrFun penalty domains xss_train ys_train xss_val ys_val sol
  | not isRefit && (not . null . _fit) sol = sol
  | otherwise = sol{ _chromo  = fitted
                   , _fit     = fitness
                   , _weights = ws
                   , _constr  = cnst
                   , _len     = len
                   , _penalty = pnlty
                   }
  where
    -- Fit the rational IT using only the terms that are valid in the domains
    tir = selectValidTerms (_chromo sol) domains
    ws  = fitTask task tir xss_train ys_train

    -- first weight vector as a boxed vector: bias followed by the coefficients
    theta :: V.Vector Double
    theta = VS.convert (head ws)

    -- Validate (it should be applied to every different weights set)
    fitted  = replaceConsts tir (V.tail theta)
    -- a NaN measure is turned into infinity by 'nan2inf'
    fitness = map nan2inf
            . fromJust
            . evalTest task measures xss_val ys_val
            $ sol{ _chromo = fitted, _weights = ws }

    -- Length and constraint
    tree  = assembleTree (V.head theta) fitted
    len   = countNodes tree
    cnst  = cnstrFun tree
    pnlty = evalPenalty penalty len cnst


-- | Evaluates an already fitted individual on a data set. Unlike
-- 'evalTrain', no coefficient fitting is performed: the weights stored in
-- the individual are used as they are.
--
-- Returns 'Nothing' when the individual has no weights; otherwise returns
-- the value of every measure in @measures@, in the same order.
--
-- Each weight vector is interpreted as a bias (first element) followed by
-- the term coefficients, and is evaluated with its own bias, so tasks that
-- produce more than one weight vector do not share the bias of the first one.
--
-- NOTE: the first column of @xss@ is dropped before evaluation — presumably
-- the constant column paired with the bias; confirm against the data loader.
-- Every weight vector is assumed to be non-empty ('V.head' is partial).
evalTest :: Task -> [Measure] -> Dataset Double -> Vector Double -> Individual -> Maybe [Double]
evalTest task measures xss ys sol
  | null weights = Nothing
  | otherwise    = Just
                 . applyMeasures measures ys
                 . predictTask task
                 $ map predictWith weights
  where
    tir     = _chromo sol
    weights = _weights sol
    xss'    = V.tail xss

    -- output of the expression for a single weight vector
    predictWith w = evalTIR xss' (V.head theta) (replaceConsts tir (V.tail theta))
      where
        theta :: V.Vector Double
        theta = VS.convert w

-- | Evaluates a TIR expression over a data set, returning one value per
-- sample. With @g@ the outer function and @p@, @q@ the term lists of the
-- expression, the result is
--
-- > g ((bias + sum p) / (1 + sum q))
--
-- where each term @(w, h, ks)@ contributes @w * h (product of x_ix ^^ k)@
-- for every @(ix, k)@ in @ks@.
--
-- NOTE: columns are looked up with the partial '!', so every variable index
-- used by a term is assumed to be a valid index of @xss@.
evalTIR :: Dataset Double -> Double -> TIR -> LA.Vector Double
evalTIR xss bias (TIR g p q) = evalFun g ((LA.scalar bias + evalSigma p) / (1 + evalSigma q))
  where
    -- weighted sum of the transformed interactions of a list of terms
    -- (right fold, so the terms are accumulated from the last to the first)
    evalSigma terms = foldr (\(w, h, ks) acc -> LA.scalar w * evalFun h (evalPi ks) + acc) 0 terms

    -- interaction: element-wise product of the selected columns, each one
    -- raised to its (possibly negative) integer exponent
    evalPi ks = foldr (\(ix, k) acc -> acc * (xss ! ix) ^^ k) 1 ks
{-# INLINE evalTIR #-}
    
--instance OptIntPow (LA.Vector Double) where
--  (^.) = (^^)


-- | Replaces @NaN@ by positive infinity; any other value is returned unchanged.
nan2inf :: RealFloat a => a -> a
nan2inf x
  | isNaN x   = 1 / 0
  | otherwise = x
{-# INLINE nan2inf #-}