// money/rnn.go

package main
import (
"log"
"math"
"math/rand"
"gonum.org/v1/gonum/mat"
)
/* Dataset + LSTM + helpers as listed */
// ================== LSTM (gonum) ==================
// LSTM is a single-layer LSTM with a scalar linear output head.
// All parameters and state are gonum column-major dense matrices.
type LSTM struct {
	InputSize  int // rows of each per-timestep input column vector
	HiddenSize int // rows of the hidden/cell state vectors

	// Gate parameters, one triple per gate: Wx* multiplies the input,
	// Wh* multiplies the previous hidden state, b* is the bias.
	// i = input gate, f = forget gate, o = output gate, g = tanh candidate.
	Wxi, Whi, bi *mat.Dense
	Wxf, Whf, bf *mat.Dense
	Wxo, Who, bo *mat.Dense
	Wxg, Whg, bg *mat.Dense

	// Output head: y = Why*h + by (1x1 result read in forward).
	Why, by *mat.Dense

	// Running hidden and cell state (HiddenSize x 1), advanced by forward
	// and cleared by resetState.
	h, c *mat.Dense

	// LR is the SGD learning rate applied in backward.
	LR float64
}
// newLSTM builds a single-layer LSTM with a scalar linear output head.
// Weights are drawn uniformly from +/-0.1/sqrt(fan-in) with a deterministic
// RNG seeded by seed; biases and the initial hidden/cell state are zero.
func newLSTM(inputSize, hiddenSize int, lr float64, seed int64) *LSTM {
	rng := rand.New(rand.NewSource(seed))

	// randMat returns an r x c matrix of small uniform values scaled by fan-in.
	randMat := func(r, c int) *mat.Dense {
		buf := make([]float64, r*c)
		scale := 1.0 / math.Sqrt(float64(c))
		for k := range buf {
			buf[k] = (rng.Float64()*2 - 1) * 0.1 * scale
		}
		return mat.NewDense(r, c, buf)
	}
	// zeroMat returns an all-zero r x c matrix.
	zeroMat := func(r, c int) *mat.Dense { return mat.NewDense(r, c, nil) }

	return &LSTM{
		InputSize:  inputSize,
		HiddenSize: hiddenSize,
		Wxi:        randMat(hiddenSize, inputSize),
		Whi:        randMat(hiddenSize, hiddenSize),
		bi:         zeroMat(hiddenSize, 1),
		Wxf:        randMat(hiddenSize, inputSize),
		Whf:        randMat(hiddenSize, hiddenSize),
		bf:         zeroMat(hiddenSize, 1),
		Wxo:        randMat(hiddenSize, inputSize),
		Who:        randMat(hiddenSize, hiddenSize),
		bo:         zeroMat(hiddenSize, 1),
		Wxg:        randMat(hiddenSize, inputSize),
		Whg:        randMat(hiddenSize, hiddenSize),
		bg:         zeroMat(hiddenSize, 1),
		Why:        randMat(1, hiddenSize),
		by:         zeroMat(1, 1),
		h:          zeroMat(hiddenSize, 1),
		c:          zeroMat(hiddenSize, 1),
		LR:         lr,
	}
}
// resetState zeroes the running hidden and cell state in place.
func (m *LSTM) resetState() {
	m.h.Zero()
	m.c.Zero()
}
// denseOf returns a freshly allocated dense copy of a.
func denseOf(a mat.Matrix) *mat.Dense {
	return mat.DenseCopyOf(a)
}
// mm returns the matrix product a*b in a new dense matrix.
func mm(a mat.Matrix, b mat.Matrix) *mat.Dense {
	var prod mat.Dense
	prod.Mul(a, b)
	return &prod
}
// addM returns the element-wise sum a+b in a new dense matrix.
func addM(a mat.Matrix, b mat.Matrix) *mat.Dense {
	var sum mat.Dense
	sum.Add(a, b)
	return &sum
}
// hadM returns the Hadamard (element-wise) product of a and b.
// It panics when the dimensions do not match.
func hadM(a mat.Matrix, b mat.Matrix) *mat.Dense {
	ar, ac := a.Dims()
	br, bc := b.Dims()
	if ar != br || ac != bc {
		panic("hadM: dimensioni incompatibili")
	}
	out := mat.NewDense(ar, ac, nil)
	out.MulElem(a, b)
	return out
}
// applyM returns a new dense matrix with f applied to every element of a.
func applyM(a mat.Matrix, f func(float64) float64) *mat.Dense {
	r, c := a.Dims()
	out := mat.NewDense(r, c, nil)
	out.Apply(func(_, _ int, v float64) float64 { return f(v) }, a)
	return out
}
// scaleM returns a new dense matrix equal to a scaled by s.
func scaleM(a mat.Matrix, s float64) *mat.Dense {
	var out mat.Dense
	out.Scale(s, a)
	return &out
}
// sigmoid is the logistic function 1/(1+e^-x), mapping R onto (0, 1).
func sigmoid(x float64) float64 {
	return 1.0 / (1.0 + math.Exp(-x))
}
// lstmCache stores every intermediate activation of a forward pass so that
// backward can run backpropagation through time without recomputation.
// hs and cs have T+1 entries (index 0 is the pre-sequence state); the
// per-timestep slices (xs, gates, ys) have T entries.
type lstmCache struct {
	xs, hs, cs []*mat.Dense // inputs, hidden states, cell states
	is, fs, os []*mat.Dense // input/forget/output gate activations
	gs, ys     []*mat.Dense // tanh candidates and per-step outputs
}
// forward runs the LSTM over the whole sequence seq (one scalar per
// timestep) and returns the prediction at the final timestep together with
// a cache of all intermediate activations for backward(). The persistent
// state m.h/m.c is advanced to the post-sequence state.
//
// NOTE(review): xt is built as mat.NewDense(m.InputSize, 1, []float64{seq[t]}),
// which panics unless InputSize == 1 — only scalar inputs per step work as
// written.
func (m *LSTM) forward(seq []float64) (float64, lstmCache) {
	T := len(seq)
	// hs/cs carry T+1 snapshots: index 0 is the state before step 0.
	cache := lstmCache{
		xs: make([]*mat.Dense, T), hs: make([]*mat.Dense, T+1), cs: make([]*mat.Dense, T+1),
		is: make([]*mat.Dense, T), fs: make([]*mat.Dense, T), os: make([]*mat.Dense, T),
		gs: make([]*mat.Dense, T), ys: make([]*mat.Dense, T),
	}
	cache.hs[0] = denseOf(m.h)
	cache.cs[0] = denseOf(m.c)
	// Work on copies so cached snapshots stay independent of m.h/m.c.
	h := denseOf(m.h)
	c := denseOf(m.c)
	for t := 0; t < T; t++ {
		xt := mat.NewDense(m.InputSize, 1, []float64{seq[t]})
		cache.xs[t] = xt
		// Standard LSTM gates: sigmoid input (i), forget (f), output (o)
		// gates and a tanh candidate (g).
		i := applyM(addM(addM(mm(m.Wxi, xt), mm(m.Whi, h)), m.bi), sigmoid)
		f := applyM(addM(addM(mm(m.Wxf, xt), mm(m.Whf, h)), m.bf), sigmoid)
		o := applyM(addM(addM(mm(m.Wxo, xt), mm(m.Who, h)), m.bo), sigmoid)
		g := applyM(addM(addM(mm(m.Wxg, xt), mm(m.Whg, h)), m.bg), math.Tanh)
		// State updates: c = f⊙c + i⊙g ; h = o⊙tanh(c).
		c = addM(hadM(f, c), hadM(i, g))
		h = hadM(o, applyM(c, math.Tanh))
		// Linear readout every step; only the last y becomes the prediction.
		y := addM(mm(m.Why, h), m.by)
		cache.is[t] = i
		cache.fs[t] = f
		cache.os[t] = o
		cache.gs[t] = g
		cache.hs[t+1] = denseOf(h)
		cache.cs[t+1] = denseOf(c)
		cache.ys[t] = y
	}
	pred := cache.ys[T-1].At(0, 0)
	// Persist the final state for stateful use across calls.
	m.h.Copy(cache.hs[T])
	m.c.Copy(cache.cs[T])
	return pred, cache
}
// backward runs full backpropagation through time over the cached forward
// pass, computes the Huber loss of the final prediction against target,
// clips each gradient matrix to an L2 norm of at most 5, and applies one
// plain SGD update with learning rate m.LR. It returns the (pre-update)
// loss. Only the last timestep's output enters the loss; gradients reach
// earlier steps through the recurrent connections.
func (m *LSTM) backward(cache lstmCache, target float64) float64 {
	T := len(cache.xs)
	yT := cache.ys[T-1].At(0, 0)
	// Huber loss with delta = 1: quadratic near zero, linear in the tails;
	// its derivative is the error clamped to [-delta, delta].
	delta := 1.0
	err := yT - target
	var loss, grad float64
	if math.Abs(err) <= delta {
		loss = 0.5 * err * err
		grad = err
	} else {
		loss = delta*math.Abs(err) - 0.5*delta*delta
		if err >= 0 {
			grad = delta
		} else {
			grad = -delta
		}
	}
	// Output-head gradients; dh seeds the recurrent backward sweep.
	dy := mat.NewDense(1, 1, []float64{grad})
	var dWhy mat.Dense
	dWhy.Mul(dy, cache.hs[T].T())
	dby := dy // NOTE: aliases dy; clipped below through the same matrix.
	var dh mat.Dense
	dh.Mul(m.Why.T(), dy)
	// Zero-initialized accumulators for every gate's parameter gradients.
	dWxi := mat.NewDense(m.HiddenSize, m.InputSize, nil)
	dWhi := mat.NewDense(m.HiddenSize, m.HiddenSize, nil)
	dbi := mat.NewDense(m.HiddenSize, 1, nil)
	dWxf := mat.NewDense(m.HiddenSize, m.InputSize, nil)
	dWhf := mat.NewDense(m.HiddenSize, m.HiddenSize, nil)
	dbf := mat.NewDense(m.HiddenSize, 1, nil)
	dWxo := mat.NewDense(m.HiddenSize, m.InputSize, nil)
	dWho := mat.NewDense(m.HiddenSize, m.HiddenSize, nil)
	dbo := mat.NewDense(m.HiddenSize, 1, nil)
	dWxg := mat.NewDense(m.HiddenSize, m.InputSize, nil)
	dWhg := mat.NewDense(m.HiddenSize, m.HiddenSize, nil)
	dbg := mat.NewDense(m.HiddenSize, 1, nil)
	// Gradient flowing into the cell state from the following timestep.
	dcNext := mat.NewDense(m.HiddenSize, 1, nil)
	// Walk the sequence backwards, accumulating parameter gradients.
	for t := T - 1; t >= 0; t-- {
		h := cache.hs[t+1]
		c := cache.cs[t+1]
		cPrev := cache.cs[t]
		i := cache.is[t]
		f := cache.fs[t]
		o := cache.os[t]
		g := cache.gs[t]
		x := cache.xs[t]
		hPrev := cache.hs[t]
		tanhc := applyM(c, math.Tanh)
		// Output gate: dh ⊙ tanh(c) ⊙ sigmoid'(o), with sigmoid' = o(1-o).
		do := hadM(&dh, tanhc)
		o1mo := applyM(o, func(v float64) float64 { return v * (1 - v) })
		do = hadM(do, o1mo)
		// Cell gradient: dh through h = o⊙tanh(c) (tanh' = 1 - tanh²),
		// plus the contribution carried back from t+1.
		oneMinusTanh2 := applyM(tanhc, func(v float64) float64 { return 1 - v*v })
		tmp := hadM(&dh, o)
		dc := hadM(tmp, oneMinusTanh2)
		dc = addM(dc, dcNext)
		// Forget gate: dc ⊙ cPrev ⊙ sigmoid'(f).
		df := hadM(dc, cPrev)
		f1mf := applyM(f, func(v float64) float64 { return v * (1 - v) })
		df = hadM(df, f1mf)
		// Input gate: dc ⊙ g ⊙ sigmoid'(i).
		di := hadM(dc, g)
		i1mi := applyM(i, func(v float64) float64 { return v * (1 - v) })
		di = hadM(di, i1mi)
		// Candidate: dc ⊙ i ⊙ tanh'(g).
		dg := hadM(dc, i)
		g1mg2 := applyM(g, func(v float64) float64 { return 1 - v*v })
		dg = hadM(dg, g1mg2)
		// Accumulate parameter gradients (outer products with x and hPrev).
		dWxi.Add(dWxi, mm(di, x.T()))
		dWhi.Add(dWhi, mm(di, hPrev.T()))
		dbi.Add(dbi, di)
		dWxf.Add(dWxf, mm(df, x.T()))
		dWhf.Add(dWhf, mm(df, hPrev.T()))
		dbf.Add(dbf, df)
		dWxo.Add(dWxo, mm(do, x.T()))
		dWho.Add(dWho, mm(do, hPrev.T()))
		dbo.Add(dbo, do)
		dWxg.Add(dWxg, mm(dg, x.T()))
		dWhg.Add(dWhg, mm(dg, hPrev.T()))
		dbg.Add(dbg, dg)
		// Propagate dh to the previous timestep through all four gates.
		var dhi, dhf, dho, dhg mat.Dense
		dhi.Mul(m.Whi.T(), di)
		dhf.Mul(m.Whf.T(), df)
		dho.Mul(m.Who.T(), do)
		dhg.Mul(m.Whg.T(), dg)
		var dhSum mat.Dense
		dhSum.Add(&dhi, &dhf)
		dhSum.Add(&dhSum, &dho)
		dhSum.Add(&dhSum, &dhg)
		dh = dhSum
		// Cell gradient carried to t-1 passes through the forget gate.
		dcNext = hadM(dc, f)
		_ = h // h at t+1 is not needed by the backward math; kept for clarity.
	}
	// clip rescales d in place so its L2 (Frobenius) norm is at most maxNorm.
	clip := func(d *mat.Dense, maxNorm float64) {
		r, c := d.Dims()
		sum := 0.0
		for i := 0; i < r; i++ {
			for j := 0; j < c; j++ {
				v := d.At(i, j)
				sum += v * v
			}
		}
		norm := math.Sqrt(sum)
		if norm > maxNorm && norm > 0 {
			f := maxNorm / norm
			for i := 0; i < r; i++ {
				for j := 0; j < c; j++ {
					d.Set(i, j, d.At(i, j)*f)
				}
			}
		}
	}
	maxG := 5.0
	clip(dWxi, maxG)
	clip(dWhi, maxG)
	clip(dbi, maxG)
	clip(dWxf, maxG)
	clip(dWhf, maxG)
	clip(dbf, maxG)
	clip(dWxo, maxG)
	clip(dWho, maxG)
	clip(dbo, maxG)
	clip(dWxg, maxG)
	clip(dWhg, maxG)
	clip(dbg, maxG)
	clip(&dWhy, maxG)
	clip(dby, maxG)
	// Vanilla SGD step: theta <- theta - LR * dtheta.
	negLR := -m.LR
	m.Wxi.Add(m.Wxi, scaleM(dWxi, negLR))
	m.Whi.Add(m.Whi, scaleM(dWhi, negLR))
	m.bi.Add(m.bi, scaleM(dbi, negLR))
	m.Wxf.Add(m.Wxf, scaleM(dWxf, negLR))
	m.Whf.Add(m.Whf, scaleM(dWhf, negLR))
	m.bf.Add(m.bf, scaleM(dbf, negLR))
	m.Wxo.Add(m.Wxo, scaleM(dWxo, negLR))
	m.Who.Add(m.Who, scaleM(dWho, negLR))
	m.bo.Add(m.bo, scaleM(dbo, negLR))
	m.Wxg.Add(m.Wxg, scaleM(dWxg, negLR))
	m.Whg.Add(m.Whg, scaleM(dWhg, negLR))
	m.bg.Add(m.bg, scaleM(dbg, negLR))
	m.Why.Add(m.Why, scaleM(&dWhy, negLR))
	m.by.Add(m.by, scaleM(dby, negLR))
	return loss
}
// train fits the model with per-sample SGD for up to maxEpochs epochs,
// shuffling the training set each epoch. It tracks the first epoch's average
// training loss, the best validation MAE seen, and the number of epochs run,
// stopping early once the average loss drops to earlyStopFrac of the first
// epoch's loss. batchSize only controls how the index slice is chunked; the
// update itself is per sample.
func (m *LSTM) train(dsTrain, dsVal Dataset, batchSize int, maxEpochs int, earlyStopFrac float64, seed int64) (firstLoss float64, bestValMAE float64, epochsRun int) {
	rng := rand.New(rand.NewSource(seed))

	// shuffled returns a random permutation of 0..n-1.
	shuffled := func(n int) []int {
		perm := make([]int, n)
		for i := range perm {
			perm[i] = i
		}
		rng.Shuffle(n, func(a, b int) { perm[a], perm[b] = perm[b], perm[a] })
		return perm
	}

	// validationMAE is the mean absolute error over dsVal, or NaN when empty.
	validationMAE := func() float64 {
		if len(dsVal.Seqs) == 0 {
			return math.NaN()
		}
		total := 0.0
		for i := range dsVal.Seqs {
			m.resetState()
			pred, _ := m.forward(dsVal.Seqs[i])
			total += math.Abs(pred - dsVal.Labels[i])
		}
		return total / float64(len(dsVal.Seqs))
	}

	bestValMAE = math.Inf(1)
	var referenceLoss float64
	for epoch := 1; epoch <= maxEpochs; epoch++ {
		order := shuffled(len(dsTrain.Seqs))
		lossSum := 0.0
		seen := 0
		for lo := 0; lo < len(order); lo += batchSize {
			hi := min(lo+batchSize, len(order))
			for _, sample := range order[lo:hi] {
				m.resetState()
				_, cache := m.forward(dsTrain.Seqs[sample])
				lossSum += m.backward(cache, dsTrain.Labels[sample])
				seen++
			}
		}
		avgLoss := lossSum / float64(max(1, seen))
		curVal := validationMAE()
		if epoch == 1 {
			firstLoss = avgLoss
			referenceLoss = avgLoss
		}
		if curVal < bestValMAE {
			bestValMAE = curVal
		}
		epochsRun = epoch
		log.Printf("epoca=%d avgLoss=%.6f firstLoss=%.6f valMAE=%.6f", epoch, avgLoss, firstLoss, curVal)
		if avgLoss <= referenceLoss*earlyStopFrac {
			log.Printf("early-stopping: avgLoss=%.6f soglia=%.6f (%.2f%% di firstLoss) epoca=%d", avgLoss, referenceLoss*earlyStopFrac, earlyStopFrac*100, epoch)
			break
		}
	}
	return
}