From fa63cdfbc20dd73d72c9cc82ee46ce3d5a4c1a10 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 17:03:10 +0100 Subject: [PATCH] Simplify how integration test checks for audit log presence --- integration/integration_test.go | 53 ++++++++++----------------------- 1 file changed, 16 insertions(+), 37 deletions(-) diff --git a/integration/integration_test.go b/integration/integration_test.go index 50294b50b665..6f9f4e5c4261 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -544,45 +544,8 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } } - // Stream all the session events into a slice to make them easier - // to work with. capturedStream, sessionEvents := streamSession(ctx, t, site, sessionID) - var hasStart bool - var hasEnd bool - var hasLeave bool - for _, se := range sessionEvents { - var isAuditEvent bool - if se.GetType() == events.SessionStartEvent { - isAuditEvent = true - hasStart = true - } - if se.GetType() == events.SessionEndEvent { - isAuditEvent = true - hasEnd = true - } - if se.GetType() == events.SessionLeaveEvent { - isAuditEvent = true - hasLeave = true - } - - // ensure session events are also in audit log - if !isAuditEvent { - continue - } - auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{ - To: time.Now(), - EventTypes: []string{se.GetType()}, - }) - require.NoError(t, err) - - found := slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool { - return ae.GetID() == se.GetID() - }) - require.True(t, found) - } - require.True(t, hasStart && hasEnd && hasLeave) - findByType := func(et string) apievents.AuditEvent { for _, e := range sessionEvents { if e.GetType() == et { @@ -591,6 +554,19 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } return nil } + // helper that asserts that a session event is also included in the + // general audit log. + requireInAuditLog := func(t *testing.T, sessionEvent apievents.AuditEvent) { + t.Helper() + auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{ + To: time.Now(), + EventTypes: []string{sessionEvent.GetType()}, + }) + require.NoError(t, err) + require.True(t, slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == sessionEvent.GetID() + })) + } // there should always be 'session.start' event (and it must be first) first := sessionEvents[0].(*apievents.SessionStart) @@ -598,16 +574,19 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { require.Equal(t, first, start) require.Equal(t, sessionID, start.SessionID) require.NotEmpty(t, start.TerminalSize) + requireInAuditLog(t, start) // there should always be 'session.end' event end := findByType(events.SessionEndEvent).(*apievents.SessionEnd) require.NotNil(t, end) require.Equal(t, sessionID, end.SessionID) + requireInAuditLog(t, end) // there should always be 'session.leave' event leave := findByType(events.SessionLeaveEvent).(*apievents.SessionLeave) require.NotNil(t, leave) require.Equal(t, sessionID, leave.SessionID) + requireInAuditLog(t, leave) // all of them should have a proper time for _, e := range sessionEvents {