Skip to content

Commit

Permalink
Simplify how integration test checks for audit log presence
Browse files Browse the repository at this point in the history
  • Loading branch information
strideynet committed Sep 20, 2024
1 parent e4c007c commit fa63cdf
Showing 1 changed file with 16 additions and 37 deletions.
53 changes: 16 additions & 37 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,45 +544,8 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) {
}
}

// Stream all the session events into a slice to make them easier
// to work with.
capturedStream, sessionEvents := streamSession(ctx, t, site, sessionID)

var hasStart bool
var hasEnd bool
var hasLeave bool
for _, se := range sessionEvents {
var isAuditEvent bool
if se.GetType() == events.SessionStartEvent {
isAuditEvent = true
hasStart = true
}
if se.GetType() == events.SessionEndEvent {
isAuditEvent = true
hasEnd = true
}
if se.GetType() == events.SessionLeaveEvent {
isAuditEvent = true
hasLeave = true
}

// ensure session events are also in audit log
if !isAuditEvent {
continue
}
auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{
To: time.Now(),
EventTypes: []string{se.GetType()},
})
require.NoError(t, err)

found := slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool {
return ae.GetID() == se.GetID()
})
require.True(t, found)
}
require.True(t, hasStart && hasEnd && hasLeave)

findByType := func(et string) apievents.AuditEvent {
for _, e := range sessionEvents {
if e.GetType() == et {
Expand All @@ -591,23 +554,39 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) {
}
return nil
}
// helper that asserts that a session event is also included in the
// general audit log.
requireInAuditLog := func(t *testing.T, sessionEvent apievents.AuditEvent) {
t.Helper()
auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{
To: time.Now(),
EventTypes: []string{sessionEvent.GetType()},
})
require.NoError(t, err)
require.True(t, slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool {
return ae.GetID() == sessionEvent.GetID()
}))
}

// there should always be 'session.start' event (and it must be first)
first := sessionEvents[0].(*apievents.SessionStart)
start := findByType(events.SessionStartEvent).(*apievents.SessionStart)
require.Equal(t, first, start)
require.Equal(t, sessionID, start.SessionID)
require.NotEmpty(t, start.TerminalSize)
requireInAuditLog(t, start)

// there should always be 'session.end' event
end := findByType(events.SessionEndEvent).(*apievents.SessionEnd)
require.NotNil(t, end)
require.Equal(t, sessionID, end.SessionID)
requireInAuditLog(t, end)

// there should always be 'session.leave' event
leave := findByType(events.SessionLeaveEvent).(*apievents.SessionLeave)
require.NotNil(t, leave)
require.Equal(t, sessionID, leave.SessionID)
requireInAuditLog(t, leave)

// all of them should have a proper time
for _, e := range sessionEvents {
Expand Down

0 comments on commit fa63cdf

Please sign in to comment.