diff --git a/api_test.go b/api_test.go index 7511980..606a6e8 100644 --- a/api_test.go +++ b/api_test.go @@ -12,7 +12,7 @@ import ( "github.com/hashicorp/raft" ) -func chunkData(t *testing.T) ([]byte, []raft.Log) { +func chunkData(t *testing.T) ([]byte, []*raft.Log) { data := make([]byte, 6000000) n, err := rand.Read(data) if err != nil && err != io.EOF { @@ -22,14 +22,17 @@ func chunkData(t *testing.T) ([]byte, []raft.Log) { t.Fatalf("expected 6000k bytes to test with, read %d", n) } - logs := make([]raft.Log, 0) + logs := make([]*raft.Log, 0) dur := time.Second + var index uint64 applyFunc := func(l raft.Log, d time.Duration) raft.ApplyFuture { if d != dur { t.Fatalf("expected d to be %v, got %v", time.Second, dur) } - logs = append(logs, l) + index++ + l.Index = index + logs = append(logs, &l) return raft.ApplyFuture(nil) } diff --git a/fsm.go b/fsm.go index 53e0e00..374b477 100644 --- a/fsm.go +++ b/fsm.go @@ -11,6 +11,7 @@ import ( var _ raft.FSM = (*ChunkingFSM)(nil) var _ raft.ConfigurationStore = (*ChunkingConfigurationStore)(nil) +var _ raft.BatchingFSM = (*ChunkingBatchingFSM)(nil) type ChunkingSuccess struct { Response interface{} @@ -28,6 +29,11 @@ type ChunkingFSM struct { lastTerm uint64 } +type ChunkingBatchingFSM struct { + *ChunkingFSM + underlyingBatchingFSM raft.BatchingFSM +} + type ChunkingConfigurationStore struct { *ChunkingFSM underlyingConfigurationStore raft.ConfigurationStore @@ -44,6 +50,20 @@ func NewChunkingFSM(underlying raft.FSM, store ChunkStorage) *ChunkingFSM { return ret } +func NewChunkingBatchingFSM(underlying raft.BatchingFSM, store ChunkStorage) *ChunkingBatchingFSM { + ret := &ChunkingBatchingFSM{ + ChunkingFSM: &ChunkingFSM{ + underlying: underlying, + store: store, + }, + underlyingBatchingFSM: underlying, + } + if store == nil { + ret.ChunkingFSM.store = NewInmemChunkStorage() + } + return ret +} + func NewChunkingConfigurationStore(underlying raft.ConfigurationStore, store ChunkStorage) *ChunkingConfigurationStore { ret := &ChunkingConfigurationStore{ ChunkingFSM: &ChunkingFSM{ @@ -58,14 +78,7 @@ func NewChunkingConfigurationStore(underlying raft.ConfigurationStore, store Chu return ret } -// Apply applies the log, handling chunking as needed. The return value will -// either be an error or whatever is returned from the underlying Apply. -func (c *ChunkingFSM) Apply(l *raft.Log) interface{} { - // Not chunking or wrong type, pass through - if l.Type != raft.LogCommand || l.Extensions == nil { - return c.underlying.Apply(l) - } - +func (c *ChunkingFSM) applyChunk(l *raft.Log) (*raft.Log, error) { if l.Term != c.lastTerm { // Term has changed. A raft library client that was applying chunks // should get an error that it's no longer the leader and bail, and @@ -73,7 +86,7 @@ func (c *ChunkingFSM) Apply(l *raft.Log) interface{} { // chunking operation automatically, which will be under a different // opnum. So it should be safe in this case to clear the map. if err := c.store.RestoreChunks(nil); err != nil { - return err + return nil, err } c.lastTerm = l.Term } @@ -81,7 +94,7 @@ func (c *ChunkingFSM) Apply(l *raft.Log) interface{} { // Get chunk info from extensions var ci types.ChunkInfo if err := proto.Unmarshal(l.Extensions, &ci); err != nil { - return errwrap.Wrapf("error unmarshaling chunk info: {{err}}", err) + return nil, errwrap.Wrapf("error unmarshaling chunk info: {{err}}", err) } // Store the current chunk and find out if all chunks have arrived @@ -93,19 +106,20 @@ func (c *ChunkingFSM) Apply(l *raft.Log) interface{} { Data: l.Data, }) if err != nil { - return err + return nil, err } if !done { - return nil + return nil, nil } // All chunks are here; get the full set and clear storage of the op chunks, err := c.store.FinalizeOp(ci.OpNum) if err != nil { - return err + return nil, err } finalData := make([]byte, 0, len(chunks)*raft.SuggestedMaxDataSize) + for _, chunk := range chunks { finalData = append(finalData, chunk.Data...) } @@ -119,7 +133,27 @@ func (c *ChunkingFSM) Apply(l *raft.Log) interface{} { Extensions: ci.NextExtensions, } - return ChunkingSuccess{Response: c.underlying.Apply(logToApply)} + return logToApply, nil +} + +// Apply applies the log, handling chunking as needed. The return value will +// either be an error or whatever is returned from the underlying Apply. +func (c *ChunkingFSM) Apply(l *raft.Log) interface{} { + // Not chunking or wrong type, pass through + if l.Type != raft.LogCommand || l.Extensions == nil { + return c.underlying.Apply(l) + } + + logToApply, err := c.applyChunk(l) + if err != nil { + return err + } + + if logToApply != nil { + return ChunkingSuccess{Response: c.underlying.Apply(logToApply)} + } + + return nil } func (c *ChunkingFSM) Snapshot() (raft.FSMSnapshot, error) { @@ -157,3 +191,68 @@ func (c *ChunkingFSM) RestoreState(state *State) error { func (c *ChunkingConfigurationStore) StoreConfiguration(index uint64, configuration raft.Configuration) { c.underlyingConfigurationStore.StoreConfiguration(index, configuration) } + +// ApplyBatch applies the logs, handling chunking as needed. The return value will +// be an array containing an error or whatever is returned from the underlying +// Apply for each log. +func (c *ChunkingBatchingFSM) ApplyBatch(logs []*raft.Log) []interface{} { + // responses has a response for each log; their slice index should match. + responses := make([]interface{}, len(logs)) + + // sentLogs keeps track of which logs we sent. The key is the raft Index + // associated with the log and the value is true if this is a finalized set + // of chunks. + sentLogs := make(map[uint64]bool) + + // sendLogs is the subset of logs that we need to pass onto the underlying + // FSM. + sendLogs := make([]*raft.Log, 0, len(logs)) + + for i, l := range logs { + // Not chunking or wrong type, pass through + if l.Type != raft.LogCommand || l.Extensions == nil { + sendLogs = append(sendLogs, l) + sentLogs[l.Index] = false + continue + } + + logToApply, err := c.applyChunk(l) + if err != nil { + responses[i] = err + continue + } + + if logToApply != nil { + sendLogs = append(sendLogs, logToApply) + sentLogs[l.Index] = true + } + } + + // Send remaining logs to the underlying FSM. + var sentResponses []interface{} + if len(sendLogs) > 0 { + sentResponses = c.underlyingBatchingFSM.ApplyBatch(sendLogs) + } + + var sentCounter int + for j, l := range logs { + // If the response is already set we errored above and should continue + // onto the next. + if responses[j] != nil { + continue + } + + var resp interface{} + if chunked, ok := sentLogs[l.Index]; ok { + resp = sentResponses[sentCounter] + if chunked { + resp = ChunkingSuccess{Response: sentResponses[sentCounter]} + } + sentCounter++ + } + + responses[j] = resp + } + + return responses +} diff --git a/fsm_test.go b/fsm_test.go index c361f74..7db602d 100644 --- a/fsm_test.go +++ b/fsm_test.go @@ -8,6 +8,19 @@ import ( "github.com/hashicorp/raft" ) +type MockBatchFSM struct { + *MockFSM +} + +func (m *MockBatchFSM) ApplyBatch(logs []*raft.Log) []interface{} { + responses := make([]interface{}, len(logs)) + for i, l := range logs { + responses[i] = m.Apply(l) + } + + return responses +} + type MockFSM struct { logs [][]byte } @@ -32,7 +45,7 @@ func TestFSM_Basic(t *testing.T) { data, logs := chunkData(t) for i, l := range logs { - r := f.Apply(&l) + r := f.Apply(l) switch r.(type) { case nil: if i == len(logs)-1 { @@ -72,7 +85,7 @@ func TestFSM_StateHandling(t *testing.T) { if i == len(logs)-1 { break } - r := f.Apply(&l) + r := f.Apply(l) switch r.(type) { case nil: case error: @@ -118,7 +131,7 @@ func TestFSM_StateHandling(t *testing.T) { t.Fatal(diff) } - r := f.Apply(&(logs[len(logs)-1])) + r := f.Apply(logs[len(logs)-1]) rRaw, ok := r.(ChunkingSuccess) if !ok { t.Fatalf("wrong type back: %T, value is %#v", r, r) @@ -160,3 +173,94 @@ func TestFSM_StateHandling(t *testing.T) { t.Fatal(diff) } } + +func TestBatchingFSM(t *testing.T) { + m := &MockBatchFSM{ + MockFSM: new(MockFSM), + } + f := NewChunkingBatchingFSM(m, nil) + _, logs := chunkData(t) + + responses := f.ApplyBatch(logs) + for i, r := range responses { + switch r.(type) { + case nil: + if i == len(logs)-1 { + t.Fatal("got nil, expected ChunkingSuccess") + } + case error: + t.Fatal(r.(error)) + case ChunkingSuccess: + if i != len(logs)-1 { + t.Fatal("got int back before apply should have happened") + } + if r.(ChunkingSuccess).Response.(int) != 1 { + t.Fatalf("unexpected number of logs back: %d", r.(int)) + } + default: + t.Fatal("unexpected return value") + } + } +} + +func TestBatchingFSM_MixedData(t *testing.T) { + m := &MockBatchFSM{ + MockFSM: new(MockFSM), + } + f := NewChunkingBatchingFSM(m, nil) + _, logs := chunkData(t) + + lastSeen := 0 + for i := range logs { + batch := make([]*raft.Log, len(logs)) + for j := 0; j < len(logs); j++ { + index := uint64((i * len(logs)) + j) + if i == j { + l := logs[i] + l.Index = index + batch[j] = l + } else { + batch[j] = &raft.Log{ + Index: index, + Data: []byte("test"), + Type: raft.LogCommand, + } + } + } + + responses := f.ApplyBatch(batch) + for j, r := range responses { + switch r.(type) { + case nil: + if j != i { + t.Fatal("got unexpected nil") + } + case error: + t.Fatal(r.(error)) + case int: + if j == i { + t.Fatal("got unexpected int") + } + if r.(int) != lastSeen+1 { + t.Fatalf("unexpected number of logs back: %d, expected %d", r.(int), lastSeen+1) + } + + lastSeen++ + case ChunkingSuccess: + if i != len(logs)-1 && j != i { + t.Fatal("got int back before apply should have happened") + } + if r.(ChunkingSuccess).Response.(int) != lastSeen+1 { + t.Fatalf("unexpected number of logs back: %d", r.(ChunkingSuccess).Response.(int)) + } + lastSeen++ + default: + t.Fatal("unexpected return value") + } + } + } + if lastSeen != 11*12+1 { + t.Fatalf("unexpected total logs processed: %d", lastSeen) + } + +} diff --git a/go.mod b/go.mod index ff3c15b..8e126cd 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/go-test/deep v1.0.2 github.com/golang/protobuf v1.3.1 github.com/hashicorp/errwrap v1.0.0 - github.com/hashicorp/raft v1.1.1 + github.com/hashicorp/raft v1.1.2-0.20191002163536-9c6bd3e3eb17 github.com/kr/pretty v0.1.0 github.com/mitchellh/copystructure v1.0.0 ) diff --git a/go.sum b/go.sum index 9494f6c..9498075 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/hashicorp/golang-lru v0.5.0 h1:CL2msUPvZTLb5O648aiLNJw3hnBxN2+1Jq8rCO github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/raft v1.1.1 h1:HJr7UE1x/JrJSc9Oy6aDBHtNHUUBHjcQjTgvUVihoZs= github.com/hashicorp/raft v1.1.1/go.mod h1:vPAJM8Asw6u8LxC3eJCUZmRP/E4QmUGE1R7g7k8sG/8= +github.com/hashicorp/raft v1.1.2-0.20191002163536-9c6bd3e3eb17 h1:p+2EISNdFCnD9R+B4xCiqSn429MCFtvM41aHJDJ6qW4= +github.com/hashicorp/raft v1.1.2-0.20191002163536-9c6bd3e3eb17/go.mod h1:vPAJM8Asw6u8LxC3eJCUZmRP/E4QmUGE1R7g7k8sG/8= github.com/hashicorp/raft-boltdb v0.0.0-20171010151810-6e5ba93211ea/go.mod h1:pNv7Wc3ycL6F5oOWn+tPGo2gWD4a5X+yp/ntwdKLjRk= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=