{-# LANGUAGE FlexibleContexts #-}
module IT where
import Data.List (intercalate)
import qualified Data.IntMap.Strict as M
import qualified Numeric.LinearAlgebra as LA
import qualified Data.Vector as V
-- | An interaction: an 'M.IntMap' from 'Int' keys to 'Int' values.
-- The printers in this module render each entry as variable @x<key>@
-- raised to the power @<value>@.
type Interaction = M.IntMap Int
-- | Tag naming the function that a 'Term' applies to its interaction.
--
-- The derived 'Ord' instance follows the declaration order below, and the
-- derived 'Show' / 'Read' instances use the constructor names verbatim.
data Transformation
  = Id
  | Sin
  | Cos
  | Tan
  | Tanh
  | Sqrt
  | SqrtAbs
  | Log
  | Exp
  | Log1p
  deriving (Show, Read, Eq, Ord)
-- | A single term of an expression: a 'Transformation' paired with the
-- 'Interaction' it is applied to.
--
-- Only 'Show' and 'Read' are derived here.
data Term = Term Transformation Interaction
  deriving (Show, Read)
-- | An expression is a list of terms.
type Expr = [Term]
-- | A column is an 'LA.Vector' of values.
type Column a = LA.Vector a
-- | A dataset is a 'V.Vector' of columns.
type Dataset a = V.Vector (Column a)
-- | Render an expression together with its fitted coefficients as a string.
--
-- Arguments, in order:
--
-- * a renderer for one @(key, value)@ entry of an interaction; entries that
--   render to the empty string are dropped before the rest are joined
--   with @*@;
-- * a renderer for the transformation of a term;
-- * the expression to print;
-- * the coefficient list: its head is printed first, the remaining values
--   are paired positionally with the terms.
--
-- An empty expression yields the empty string. A non-empty expression with an
-- empty coefficient list raises an 'error'. Because the pairing uses
-- 'zipWith', surplus weights or surplus terms are silently ignored.
prettyPrint :: ((Int, Int) -> String)
            -> (Transformation -> String)
            -> Expr
            -> [Double]
            -> String
prettyPrint _ _ [] _ = ""
prettyPrint _ _ _ [] = error "ERROR: prettyPrint on non fitted expression."
prettyPrint k2str t2str terms (b:ws) = show b ++ " + " ++ expr
  where
    -- Weighted terms joined by " + ".
    expr :: String
    expr = intercalate " + " (zipWith weight2str ws (map terms2str terms))

    -- Product of the rendered entries of one interaction.
    interaction2str :: Interaction -> String
    interaction2str = intercalate "*" . filter (/= "") . map k2str . M.toList

    -- One term: transformation name followed by the parenthesised interaction.
    terms2str :: Term -> String
    terms2str (Term t ks) = t2str t ++ "(" ++ interaction2str ks ++ ")"

    -- Prefix an already rendered term with its weight.
    weight2str :: Double -> String -> String
    weight2str w t = show w ++ "*" ++ t
-- | Render an expression and its coefficients using 'prettyPrint'.
--
-- Transformations are rendered with 'show'. An interaction entry @(n, k)@ is
-- rendered as nothing when @k@ is 0, as @x<n>@ when @k@ is 1, and as
-- @x<n>^(<k>)@ otherwise.
toExprStr :: Expr -> [Double] -> String
toExprStr = prettyPrint k2str show
  where
    k2str :: (Int, Int) -> String
    k2str (_, 0) = ""
    k2str (n, 1) = 'x' : show n
    k2str (n, k) = ('x' : show n) ++ "^(" ++ show k ++ ")"
-- | Render an expression and its coefficients using 'prettyPrint', with
-- variables written as column slices @x[:,<n>]@ and powers as @**(<k>)@.
--
-- An interaction entry @(n, k)@ is rendered as nothing when @k@ is 0, as
-- @x[:,<n>]@ when @k@ is 1, and as @x[:,<n>]**(<k>)@ otherwise.
toPython :: Expr -> [Double] -> String
toPython = prettyPrint k2str numpy
  where
    k2str :: (Int, Int) -> String
    k2str (_, 0) = ""
    k2str (n, 1) = "x[:," ++ show n ++ "]"
    k2str (n, k) = "x[:," ++ show n ++ "]" ++ "**(" ++ show k ++ ")"

    -- Name emitted in front of the parenthesised interaction.
    -- 'Id' emits nothing, leaving just the parentheses.
    -- NOTE(review): 'SqrtAbs' emits the bare name "sqrtAbs" (no "np."
    -- prefix); presumably the consumer of the output defines it -- confirm.
    numpy :: Transformation -> String
    numpy Id      = ""
    numpy Sin     = "np.sin"
    numpy Cos     = "np.cos"
    numpy Tan     = "np.tan"
    numpy Tanh    = "np.tanh"
    numpy Sqrt    = "np.sqrt"
    numpy SqrtAbs = "sqrtAbs"
    numpy Exp     = "np.exp"
    numpy Log     = "np.log"
    numpy Log1p   = "np.log1p"
-- | Two terms are equal when their interactions are equal; the
-- transformation is deliberately not compared.
instance Eq Term where
  (Term _ i1) == (Term _ i2) = i1 == i2
-- | Drop the term at zero-based position @i@ from an expression.
--
-- When @i@ is negative or not smaller than the number of terms, the
-- expression is returned unchanged.
removeIthTerm :: Int -> Expr -> Expr
removeIthTerm i terms = take i terms ++ drop (i + 1) terms
-- | Look up the term at zero-based position @ix@.
--
-- Returns 'Nothing' when @ix@ is out of range on either side. A negative
-- index is rejected explicitly so the lookup never raises.
getIthTerm :: Int -> Expr -> Maybe Term
getIthTerm ix terms
  | ix < 0    = Nothing
  | otherwise =
      -- 'drop' stops at the end of the list, so an index past the end
      -- simply leaves nothing to match.
      case drop ix terms of
        (t:_) -> Just t
        []    -> Nothing
-- | Extract the interaction of a term, discarding its transformation.
getInteractions :: Term -> Interaction
getInteractions (Term _ ks) = ks
-- | Count the symbols of an expression given its coefficient list (head
-- first, then one weight per term).
--
-- An empty expression has length 0. A non-empty expression with an empty
-- coefficient list raises an 'error'. Otherwise the result is the sum of:
--
-- * 2 when the head coefficient is non-zero, 0 otherwise;
-- * the number of terms with a non-zero weight, minus one;
-- * 2 for every non-zero weight that differs from 1;
-- * 'termLength' of every term.
exprLength :: Expr -> [Double] -> Int
exprLength [] _ = 0
exprLength _ [] = error "ERROR: length of unfitted expression."
exprLength terms (b:ws) = biasTerm + addSymbs + weightSymbs + totTerms
  where
    -- Weight/term pairs whose weight is not zero.
    nonNullTerms :: [(Double, Term)]
    nonNullTerms = filter ((/= 0) . fst) $ zip ws terms

    ws' :: [Double]
    ws' = map fst nonNullTerms

    terms' :: [Term]
    terms' = map snd nonNullTerms

    biasTerm :: Int
    biasTerm = if b == 0 then 0 else 2

    weightSymbs :: Int
    weightSymbs = 2 * length (filter (/= 1) ws')

    addSymbs :: Int
    addSymbs = length terms' - 1

    -- NOTE(review): this sums over every term, including those whose weight
    -- is zero, while the counts above only consider non-zero weights.
    -- Confirm whether terms' was intended here.
    totTerms :: Int
    totTerms = sum (map termLength terms)
-- | Count the symbols of a single term: the length of its interaction, plus
-- one for the transformation unless the transformation is 'Id'.
termLength :: Term -> Int
termLength (Term Id ks) = interactionLength ks
termLength (Term _ ks)  = 1 + interactionLength ks
-- | Count the symbols of an interaction, looking only at entries whose value
-- is non-zero. With @n@ such entries the result is the sum of:
--
-- * @n - 1@;
-- * @n@;
-- * 2 for every entry whose value differs from 1.
--
-- NOTE(review): when no entry is non-zero the result is -1; confirm that
-- callers expect this.
interactionLength :: Interaction -> Int
interactionLength ks = mulSymbs + termSymbs + powSymbs
  where
    -- Values that are not zero.
    elems :: [Int]
    elems = filter (/= 0) $ M.elems ks

    mulSymbs :: Int
    mulSymbs = length elems - 1

    termSymbs :: Int
    termSymbs = length elems

    powSymbs :: Int
    powSymbs = 2 * length (filter (/= 1) elems)