/
MutualInformation.go
53 lines (43 loc) · 1.19 KB
/
MutualInformation.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
package discrete
import (
"math"
)
// MutualInformation calculates the mutual information of a discrete joint
// distribution, using log as the logarithm:
//
//	I(X,Y) = \sum_{x,y} p(x,y) (log(p(x,y)) - log(p(x)p(y)))
//
// pxy[x][y] holds the joint probability p(x,y). The marginals p(x) and p(y)
// are obtained by summing the rows and columns of pxy, respectively. The base
// of log determines the unit of the result (math.Log gives nats, math.Log2
// gives bits). A cell contributes only when p(x,y) > 0 and both marginals are
// positive, which follows the convention 0·log 0 = 0.
//
// A nil or empty matrix, or one whose rows are all empty, yields 0. Rows may
// differ in length; missing trailing cells are treated as zero probability.
// The entries are not checked to be non-negative or to sum to 1.
func MutualInformation(pxy [][]float64, log lnFunc) float64 {
	// Take the column count from the longest row, so that ragged input can
	// neither index past the end of a short row nor drop cells of a long one.
	yDim := 0
	for _, row := range pxy {
		if len(row) > yDim {
			yDim = len(row)
		}
	}
	if yDim == 0 {
		return 0
	}

	// Accumulate both marginals in a single pass over the cells.
	px := make([]float64, len(pxy))
	py := make([]float64, yDim)
	for x, row := range pxy {
		for y, p := range row {
			px[x] += p
			py[y] += p
		}
	}

	mi := 0.0
	for x, row := range pxy {
		if px[x] <= 0 {
			continue
		}
		for y, p := range row {
			if p > 0 && py[y] > 0 {
				mi += p * (log(p) - log(px[x]*py[y]))
			}
		}
	}
	return mi
}
// MutualInformationBaseE returns the mutual information of the joint
// distribution pxy measured in nats, i.e. using the natural logarithm:
//
//	I(X,Y) = \sum_{x,y} p(x,y) (ln(p(x,y)) - ln(p(x)p(y)))
func MutualInformationBaseE(pxy [][]float64) float64 {
	natural := math.Log
	return MutualInformation(pxy, natural)
}
// MutualInformationBase2 returns the mutual information of the joint
// distribution pxy measured in bits, i.e. using the base-2 logarithm:
//
//	I(X,Y) = \sum_{x,y} p(x,y) (log2(p(x,y)) - log2(p(x)p(y)))
func MutualInformationBase2(pxy [][]float64) float64 {
	binary := math.Log2
	return MutualInformation(pxy, binary)
}