diff --git a/modules/modules.go b/modules/modules.go index 4aee08d21..60e68ca25 100644 --- a/modules/modules.go +++ b/modules/modules.go @@ -59,11 +59,25 @@ func (m *Manager) RegisterModule(name string, initFn func() (services.Service, e // AddDependency adds a dependency from name(source) to dependsOn(targets) // An error is returned if the source module name is not found func (m *Manager) AddDependency(name string, dependsOn ...string) error { - if mod, ok := m.modules[name]; ok { - mod.deps = append(mod.deps, dependsOn...) - } else { + mod, ok := m.modules[name] + if !ok { return fmt.Errorf("no such module: %s", name) } + + // Ensure it doesn't introduce any circular dependency. + for _, newDep := range dependsOn { + if _, ok := m.modules[newDep]; !ok { + return fmt.Errorf("no such module: %s", newDep) + } + + for _, prevDep := range m.DependenciesForModule(newDep) { + if prevDep == name { + return fmt.Errorf("found a circular dependency: %s depends on %s", newDep, name) + } + } + } + + mod.deps = append(mod.deps, dependsOn...) return nil } @@ -92,7 +106,7 @@ func (m *Manager) initModule(name string, initMap map[string]bool, servicesMap m deps := m.orderedDeps(name) deps = append(deps, name) // lastly, initialize the requested module - for ix, n := range deps { + for _, n := range deps { // Skip already initialized modules if initMap[n] { continue @@ -111,7 +125,7 @@ func (m *Manager) initModule(name string, initMap map[string]bool, servicesMap m if s != nil { // We pass servicesMap, which isn't yet complete. By the time service starts, // it will be fully built, so there is no need for extra synchronization. - serv = newModuleServiceWrapper(servicesMap, n, m.logger, s, m.DependenciesForModule(n), m.findInverseDependencies(n, deps[ix+1:])) + serv = newModuleServiceWrapper(servicesMap, n, m.logger, s, m.DependenciesForModule(n), m.inverseDependenciesForModule(n)) } } @@ -205,12 +219,12 @@ func (m *Manager) orderedDeps(mod string) []string { return result } -// find modules in the supplied list, that depend on mod -func (m *Manager) findInverseDependencies(mod string, mods []string) []string { +// inverseDependenciesForModule returns the list of modules depending on the input module, sorted by name. +func (m *Manager) inverseDependenciesForModule(mod string) []string { result := []string(nil) - for _, n := range mods { - for _, d := range m.modules[n].deps { + for n := range m.modules { + for _, d := range m.DependenciesForModule(n) { if d == mod { result = append(result, n) break @@ -218,6 +232,7 @@ func (m *Manager) findInverseDependencies(mod string, mods []string) []string { } } + sort.Strings(result) return result } diff --git a/modules/modules_test.go b/modules/modules_test.go index 13b041f3f..a376204ba 100644 --- a/modules/modules_test.go +++ b/modules/modules_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sort" "testing" "time" @@ -45,9 +46,8 @@ func TestDependencies(t *testing.T) { assert.NoError(t, mm.AddDependency("serviceC", "serviceB")) assert.Equal(t, mm.modules["serviceB"].deps, []string{"serviceA"}) - invDeps := mm.findInverseDependencies("serviceA", []string{"serviceB", "serviceC"}) - require.Len(t, invDeps, 1) - assert.Equal(t, invDeps[0], "serviceB") + invDeps := mm.inverseDependenciesForModule("serviceA") + assert.Equal(t, []string{"serviceB", "serviceC"}, invDeps) // Test unknown module svc, err := mm.InitModuleServices("service_unknown") @@ -63,16 +63,83 @@ func TestDependencies(t *testing.T) { svc, err = mm.InitModuleServices("serviceA", "serviceB") assert.Nil(t, err) assert.Equal(t, 2, len(svc)) + assert.Equal(t, []string{"serviceB"}, getStopDependenciesForModule("serviceA", svc)) + assert.Equal(t, []string(nil), getStopDependenciesForModule("serviceB", svc)) svc, err = mm.InitModuleServices("serviceC") assert.NoError(t, err) assert.Equal(t, 3, len(svc)) + assert.Equal(t, []string{"serviceB", "serviceC"}, getStopDependenciesForModule("serviceA", svc)) + assert.Equal(t, []string{"serviceC"}, getStopDependenciesForModule("serviceB", svc)) + assert.Equal(t, []string(nil), getStopDependenciesForModule("serviceC", svc)) // Test loading of the module second time - should produce the same set of services, but new instances. svc2, err := mm.InitModuleServices("serviceC") assert.NoError(t, err) assert.Equal(t, 3, len(svc)) assert.NotEqual(t, svc, svc2) + assert.Equal(t, []string{"serviceB", "serviceC"}, getStopDependenciesForModule("serviceA", svc)) + assert.Equal(t, []string{"serviceC"}, getStopDependenciesForModule("serviceB", svc)) + assert.Equal(t, []string(nil), getStopDependenciesForModule("serviceC", svc)) +} + +func TestManaged_AddDependency_ShouldErrorOnCircularDependencies(t *testing.T) { + var testModules = map[string]module{ + "serviceA": { + initFn: mockInitFunc, + }, + + "serviceB": { + initFn: mockInitFunc, + }, + + "serviceC": { + initFn: mockInitFunc, + }, + } + + mm := NewManager(log.NewNopLogger()) + for name, mod := range testModules { + mm.RegisterModule(name, mod.initFn) + } + assert.NoError(t, mm.AddDependency("serviceA", "serviceB")) + assert.NoError(t, mm.AddDependency("serviceB", "serviceC")) + + // Direct circular dependency. + err := mm.AddDependency("serviceB", "serviceA") + require.Error(t, err) + assert.Contains(t, err.Error(), "circular dependency") + + // Indirect circular dependency. + err = mm.AddDependency("serviceC", "serviceA") + require.Error(t, err) + assert.Contains(t, err.Error(), "circular dependency") +} + +func TestManaged_AddDependency_ShouldErrorIfModuleDoesNotExist(t *testing.T) { + var testModules = map[string]module{ + "serviceA": { + initFn: mockInitFunc, + }, + + "serviceB": { + initFn: mockInitFunc, + }, + } + + mm := NewManager(log.NewNopLogger()) + for name, mod := range testModules { + mm.RegisterModule(name, mod.initFn) + } + assert.NoError(t, mm.AddDependency("serviceA", "serviceB")) + + // Module does not exist. + err := mm.AddDependency("serviceUnknown", "serviceA") + assert.EqualError(t, err, "no such module: serviceUnknown") + + // Dependency does not exist. + err = mm.AddDependency("serviceA", "serviceUnknown") + assert.EqualError(t, err, "no such module: serviceUnknown") } func TestRegisterModuleDefaultsToUserVisible(t *testing.T) { @@ -168,7 +235,7 @@ func TestIsModuleRegistered(t *testing.T) { assert.False(t, result, "module '%v' should NOT be registered", failureModule) } -func TestDependenciesForModule(t *testing.T) { +func TestManager_DependenciesForModule(t *testing.T) { m := NewManager(log.NewNopLogger()) m.RegisterModule("test", nil) m.RegisterModule("dep1", nil) @@ -183,6 +250,30 @@ func TestDependenciesForModule(t *testing.T) { assert.Equal(t, []string{"dep1", "dep2", "dep3"}, deps) } +func TestManager_inverseDependenciesForModule(t *testing.T) { + m := NewManager(log.NewNopLogger()) + m.RegisterModule("test", nil) + m.RegisterModule("dep1", nil) + m.RegisterModule("dep2", nil) + m.RegisterModule("dep3", nil) + + require.NoError(t, m.AddDependency("test", "dep2", "dep1")) + require.NoError(t, m.AddDependency("dep1", "dep2")) + require.NoError(t, m.AddDependency("dep2", "dep3")) + + invDeps := m.inverseDependenciesForModule("test") + assert.Equal(t, []string(nil), invDeps) + + invDeps = m.inverseDependenciesForModule("dep1") + assert.Equal(t, []string{"test"}, invDeps) + + invDeps = m.inverseDependenciesForModule("dep2") + assert.Equal(t, []string{"dep1", "test"}, invDeps) + + invDeps = m.inverseDependenciesForModule("dep3") + assert.Equal(t, []string{"dep1", "dep2", "test"}, invDeps) +} + func TestModuleWaitsForAllDependencies(t *testing.T) { var serviceA services.Service @@ -230,3 +321,13 @@ func TestModuleWaitsForAllDependencies(t *testing.T) { assert.NoError(t, services.StartManagerAndAwaitHealthy(context.Background(), servManager)) assert.NoError(t, services.StopManagerAndAwaitStopped(context.Background(), servManager)) } + +func getStopDependenciesForModule(module string, services map[string]services.Service) []string { + var deps []string + for name := range services[module].(*moduleService).stopDeps(module) { + deps = append(deps, name) + } + + sort.Strings(deps) + return deps +}