MPI version of softmax
rcoreilly committed Sep 13, 2023
1 parent 343970d commit 0dd9e97
Showing 2 changed files with 42 additions and 1 deletion.
41 changes: 41 additions & 0 deletions decoder/softmax.go
@@ -9,6 +9,7 @@ import (
"sort"

"github.com/emer/emergent/emer"
"github.com/emer/empi/mpi"
"github.com/emer/etable/etensor"
"github.com/goki/mat32"
)
@@ -46,6 +47,12 @@ type SoftMax struct {

	// synaptic weights: outer loop is units, inner loop is inputs
	Weights etensor.Float32 `desc:"synaptic weights: outer loop is units, inner loop is inputs"`

	// [view: -] mpi communicator
	Comm *mpi.Comm `view:"-" desc:"mpi communicator"`

	// delta weight changes: only for MPI mode -- outer loop is units, inner loop is inputs
	MPIDWts etensor.Float32 `desc:"delta weight changes: only for MPI mode -- outer loop is units, inner loop is inputs"`
}

// SoftMaxUnit has variables for softmax decoder unit
@@ -103,6 +110,13 @@ func (sm *SoftMax) Train(targ int) {
	sm.Back()
}

// TrainMPI trains the decoder with given target correct answer (0..NCats-1)
// MPI version uses mpi to synchronize weight changes across parallel nodes.
func (sm *SoftMax) TrainMPI(targ int) {
	sm.Target = targ
	sm.BackMPI()
}

// ValsTsr gets value tensor of given name, creating if not yet made
func (sm *SoftMax) ValsTsr(name string) *etensor.Float32 {
	if sm.ValsTsrs == nil {
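For orientation, here is a minimal, hypothetical sketch (not part of this commit) of how the new Comm field and TrainMPI might be wired up by a host simulation. It assumes empi's usual mpi.Init / mpi.NewComm / mpi.Finalize setup calls and the decoder's existing Init, Forward, and Inputs API; the Trial type and trainDecoderMPI function are invented for illustration only.

import (
	"github.com/emer/emergent/decoder"
	"github.com/emer/empi/mpi"
)

// Trial is a stand-in for whatever per-trial data the host sim provides.
type Trial struct {
	Acts []float32 // input activations for the decoder
	Targ int       // correct category (0..NCats-1)
}

// trainDecoderMPI trains one shared softmax decoder across all MPI ranks,
// with each rank seeing only its own slice of the trials.
func trainDecoderMPI(myTrials []Trial, nCats, nInputs int) *decoder.SoftMax {
	mpi.Init()
	defer mpi.Finalize()
	comm, _ := mpi.NewComm(nil) // nil = all ranks in one communicator (assumed API)

	dec := &decoder.SoftMax{}
	dec.Init(nCats, nInputs)
	dec.Comm = comm // communicator field added in this commit

	for _, t := range myTrials {
		copy(dec.Inputs, t.Acts) // however the host sim fills the decoder inputs
		dec.Forward()
		dec.TrainMPI(t.Targ) // sums weight changes across all ranks
	}
	return dec
}

Such a sketch would be launched with one MPI process per rank, e.g. mpirun -np 2 ./sim.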
@@ -188,3 +202,30 @@ func (sm *SoftMax) Back() {
		}
	}
}

// BackMPI computes the backward error propagation pass
func (sm *SoftMax) BackMPI() {
	if sm.MPIDWts.Len() == 0 {
		sm.MPIDWts.CopyShapeFrom(&sm.Weights)
	}
	lr := sm.Lrate
	for ui := range sm.Units {
		u := &sm.Units[ui]
		var del float32
		if ui == sm.Target {
			del = lr * (1 - u.Act)
		} else {
			del = -lr * u.Act
		}
		off := ui * sm.NInputs
		for j, in := range sm.Inputs {
			sm.MPIDWts.Values[off+j] = del * in
		}
	}

	sm.Comm.AllReduceF32(mpi.OpSum, sm.MPIDWts.Values, nil)

	for i, dw := range sm.MPIDWts.Values {
		sm.Weights.Values[i] += dw
	}
}
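For reference (not part of the commit): the per-unit delta is the standard softmax cross-entropy gradient step, dW[u*NInputs+j] = Lrate * (t_u - Act_u) * Inputs[j], where t_u is 1 for the Target unit and 0 otherwise, which is exactly what the lr*(1-u.Act) and -lr*u.Act branches compute. The AllReduceF32 call with OpSum then adds every rank's MPIDWts into each rank's own copy, so every process applies the identical summed update and the Weights stay in lock-step across ranks without ever being exchanged directly. Because the deltas are summed rather than averaged, one TrainMPI call behaves like a single batch update over all ranks' current trials at the given learning rate.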
2 changes: 1 addition & 1 deletion elog/stditems.go
@@ -66,7 +66,7 @@ func (lg *Logs) AddStatAggItem(statName string, times ...etime.Times) *Item {
		tm := times[i]
		if tm == etime.Run || tm == etime.Condition {
			itm.Write[etime.Scope(etime.Train, tm)] = func(ctx *Context) {
-				ix := ctx.LastNRows(etime.Train, times[i+1], 5) // cached
+				ix := ctx.LastNRows(ctx.Mode, times[i+1], 5) // cached
				ctx.SetFloat64(agg.Mean(ix, ctx.Item.Name)[0])
			}
		} else {
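Note on this one-line change (inferred from the diff, not stated in the commit message): the Run/Condition write function previously hard-coded etime.Train when calling LastNRows, so the aggregate always read rows from the Train log; passing ctx.Mode instead lets the same write function aggregate over whichever mode the item is currently being logged under.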
