{-# LANGUAGE FlexibleContexts #-}
{-|
Module      : IT
Description : IT expression data structures
Copyright   : (c) Fabricio Olivetti de Franca, 2020
License     : GPL-3
Maintainer  : fabricio.olivetti@gmail.com
Stability   : experimental
Portability : POSIX

An IT expression  represents a function of the form:

\[
f(x) = \sum_{i}{w_i \cdot t_i\left(\prod_{j}{x_j^{k_{ij}}}\right)}
\]

with \(t_i\) being a transformation function.

Any given expression can be represented by a list of terms, with each term
being composed of a transformation function and an interaction.
The transformation function is represented by a `Transformation` sum type.
The interaction is represented as an `IntMap Int` where the key is the 
predictor index and the value is the strength of the predictor in this
term. Strengths with a value of zero are omitted.
-}
module IT where

import Data.List (intercalate)

import qualified Data.IntMap.Strict as M
import qualified Numeric.LinearAlgebra as LA
import qualified Data.Vector as V

-- | An 'Interaction' maps a variable index to an integer exponent:
-- a key, value pair @(i, p)@ indicates that the i-th
-- variable should be raised to the power of @p@.
-- It is an alias for an 'M.IntMap' ("Data.IntMap.Strict") with 'Int' values.
type Interaction = M.IntMap Int

-- | The 'Transformation' type names the function that should be applied
-- to an interaction. This type only enumerates the available functions;
-- the evaluation is defined in the 'IT.Eval' module.
--
-- 'Show', 'Read', 'Eq' and 'Ord' are derived, so the ordering of the
-- values follows the order in which the constructors are declared.
data Transformation
  = Id
  | Sin
  | Cos
  | Tan
  | Tanh
  | Sqrt
  | SqrtAbs
  | Log
  | Exp
  | Log1p
  deriving (Show, Read, Eq, Ord)

-- | A 'Term' is the product type of a 'Transformation' and an 'Interaction':
-- the transformation is the function applied to the interaction.
--
-- Only 'Show' and 'Read' are derived here; 'Eq' is provided by a
-- hand-written instance in this module.
data Term = Term Transformation Interaction
  deriving (Show, Read)

-- | An 'Expr' is just a list of 'Term's. The coefficients of the terms
-- are not stored in the expression; functions that need them receive
-- them as a separate list.
type Expr = [Term]

-- | A 'Column' of a data set is stored as an 'LA.Vector'
-- ("Numeric.LinearAlgebra") to avoid conversions
-- during the fitting of the coefficients.
type Column a = LA.Vector a

-- | The 'Dataset' is a 'V.Vector' ("Data.Vector") of 'Column's for efficiency.
-- TODO: try Repa or Accelerate for large data sets.
type Dataset a = V.Vector (Column a)

-- | Base interface for converting an expression to a string.
--
-- The head of the coefficient list is printed first, followed by one
-- @weight*transformation(interaction)@ entry per term, all joined by
-- @" + "@. The factors of an interaction are joined by @"*"@ and factors
-- rendered as the empty string are dropped.
--
-- An empty expression yields the empty string, whatever the coefficients.
-- A non-empty expression with an empty coefficient list raises an 'error'.
prettyPrint :: ((Int, Int) -> String)         -- ^ A function that converts an interaction to a string
            -> (Transformation -> String)     -- ^ A function that converts a transformation to a string
            -> Expr                           -- ^ The expression to convert to string
            -> [Double]                       -- ^ The fitted coefficients
            -> String                         -- ^ A string representing the expression
prettyPrint _ _ [] _ = "" -- empty expression
prettyPrint _ _ _ [] = error "ERROR: prettyPrint on non fitted expression." -- no coefficients
prettyPrint k2str t2str terms (b:ws) = show b ++ " + " ++ expr
  where
    -- Weighted terms joined by " + ". 'zipWith' stops at the shorter of
    -- the weight list and the term list.
    expr :: String
    expr = intercalate " + " (zipWith weight2str ws (map terms2str terms))

    -- Renders every (variable, exponent) pair and joins the non-empty
    -- results with "*".
    interaction2str :: Interaction -> String
    interaction2str = intercalate "*" . filter (/= "") . map k2str . M.toList

    -- A term is the transformation name applied to the interaction.
    terms2str :: Term -> String
    terms2str (Term t ks) = t2str t ++ "(" ++ interaction2str ks ++ ")"

    -- Prefixes an already rendered term with its weight.
    weight2str :: Double -> String -> String
    weight2str w t = show w ++ "*" ++ t

-- | Converts an expression to a readable format.
--
-- Variables are written as @x@ followed by their index, exponents as
-- @^(k)@, and transformations by their 'show' representation.
toExprStr :: Expr -> [Double] -> String
toExprStr = prettyPrint k2str show
  where
    -- Renders one (variable index, exponent) pair.
    k2str :: (Int, Int) -> String
    k2str (_, 0) = ""                -- zero exponent: the variable is omitted
    k2str (n, 1) = 'x' : show n      -- exponent one: no power notation
    k2str (n, k) = ('x' : show n) ++ "^(" ++ show k ++ ")"

-- | Converts an expression to a numpy compatible format.
--
-- Variables are written as column slices @x[:,i]@, exponents as @**(k)@
-- and transformations as the function names given by the local @numpy@.
toPython :: Expr -> [Double] -> String
toPython = prettyPrint k2str numpy
  where
    -- Renders one (variable index, exponent) pair as a column slice.
    k2str :: (Int, Int) -> String
    k2str (_, 0) = ""                -- zero exponent: the variable is omitted
    k2str (n, 1) = "x[:," ++ show n ++ "]"
    k2str (n, k) = "x[:," ++ show n ++ "]" ++ "**(" ++ show k ++ ")"

    -- Name of the function to call for each transformation.
    numpy :: Transformation -> String
    numpy Id      = ""               -- identity: only the parenthesised interaction remains
    numpy Sin     = "np.sin"
    numpy Cos     = "np.cos"
    numpy Tan     = "np.tan"
    numpy Tanh    = "np.tanh"
    numpy Sqrt    = "np.sqrt"
    -- NOTE(review): unlike the others this name has no @np.@ prefix;
    -- presumably the consumer must define @sqrtAbs@ itself - confirm.
    numpy SqrtAbs = "sqrtAbs"
    numpy Exp     = "np.exp"
    numpy Log     = "np.log"
    numpy Log1p   = "np.log1p"

-- | Two terms are equal if their interactions are equal; the
-- transformation is ignored by the comparison.
-- This instance is used for the mutation operation to avoid adding
-- two interactions with the same value on the same expression.
instance Eq Term where
  Term _ i1 == Term _ i2 = i1 == i2

-- * Internal functions

-- | Removes the i-th term (zero-based) of an expression.
--
-- The result is the first @i@ terms followed by everything after
-- position @i@. When @i@ is negative or not smaller than the number of
-- terms, the expression is returned unchanged.
removeIthTerm :: Int -> Expr -> Expr
removeIthTerm i terms = take i terms ++ drop (i + 1) terms

-- | Returns the i-th term (zero-based) of an expression.
--
-- Yields 'Nothing' when the index is negative or not smaller than the
-- number of terms, so the lookup never raises an indexing error.
getIthTerm :: Int -> Expr -> Maybe Term
getIthTerm ix terms
  | ix < 0    = Nothing   -- a negative index can never address a term
  | otherwise =
      -- Dropping the first @ix@ terms leaves the wanted one at the head,
      -- or nothing at all when the expression is too short.
      case drop ix terms of
        (t : _) -> Just t
        []      -> Nothing

-- | Returns the 'Interaction' of a term, discarding its transformation.
getInteractions :: Term -> Interaction
getInteractions (Term _ ks) = ks

-- | Returns the length of an expression as in https://github.com/EpistasisLab/regression-benchmark/blob/dev/CONTRIBUTING.md
--
-- The head of the coefficient list is treated as the bias and the
-- remaining coefficients are paired, in order, with the terms.
-- An empty expression has length 0; a non-empty expression with an empty
-- coefficient list raises an 'error'.
exprLength :: Expr -> [Double] -> Int
exprLength [] _ = 0
exprLength _ [] = error "ERROR: length of unfitted expression."
exprLength terms (b:ws) = biasTerm + addSymbs + weightSymbs + totTerms
  where
    -- (weight, term) pairs whose weight is different from zero
    nonNullTerms :: [(Double, Term)]
    nonNullTerms = filter ((/= 0) . fst) $ zip ws terms

    ws' :: [Double]
    ws'          = map fst nonNullTerms

    terms' :: [Term]
    terms'       = map snd nonNullTerms

    -- a non-zero bias contributes two symbols
    biasTerm :: Int
    biasTerm     = if b == 0 then 0 else 2

    -- every non-zero weight different from 1 contributes two symbols
    weightSymbs :: Int
    weightSymbs  = 2 * length (filter (/= 1) ws')

    -- one symbol between each pair of consecutive non-null terms
    addSymbs :: Int
    addSymbs     = length terms' - 1

    -- NOTE(review): this sums over every term, including those with a
    -- zero weight that were filtered out above - confirm this is intended.
    totTerms :: Int
    totTerms     = sum (map termLength terms)
    
-- | The length of a term is the interaction length plus 1 if the
-- transformation function is not 'Id'.
termLength :: Term -> Int
termLength (Term Id ks) = interactionLength ks
termLength (Term _  ks) = 1 + interactionLength ks

-- | The interaction length is calculated as:
-- +2 for every nonzero exponent different from 1 (^k)
-- +1 for every nonzero exponent, except for the first (*)
-- +1 for every nonzero exponent (x_i)
--
-- NOTE(review): an interaction without any nonzero exponent yields -1,
-- because the multiplication count is the number of nonzero exponents
-- minus one - confirm whether callers can pass such an interaction.
interactionLength :: Interaction -> Int
interactionLength ks = mulSymbs + termSymbs + powSymbs
  where
    -- exponents that actually take part in the interaction
    elems :: [Int]
    elems     = filter (/= 0) $ M.elems ks

    -- one multiplication between each pair of consecutive variables
    mulSymbs :: Int
    mulSymbs  = length elems - 1

    -- one symbol per variable
    termSymbs :: Int
    termSymbs = length elems

    -- two symbols for every exponent that has to be written out
    powSymbs :: Int
    powSymbs  = 2 * length (filter (/= 1) elems)