Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VAULT-25710: Audit - enforce header formatter requirement in EntryFormatter #26239

Merged
merged 7 commits into from
Apr 3, 2024
61 changes: 29 additions & 32 deletions audit/entry_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,14 @@ type timeProvider interface {

// EntryFormatter should be used to format audit requests and responses.
type EntryFormatter struct {
config FormatterConfig
salter Salter
logger hclog.Logger
headerFormatter HeaderFormatter
name string
prefix string
config FormatterConfig
salter Salter
logger hclog.Logger
name string
}

// NewEntryFormatter should be used to create an EntryFormatter.
// Accepted options: WithHeaderFormatter, WithPrefix.
func NewEntryFormatter(name string, config FormatterConfig, salter Salter, logger hclog.Logger, opt ...Option) (*EntryFormatter, error) {
func NewEntryFormatter(name string, config FormatterConfig, salter Salter, logger hclog.Logger) (*EntryFormatter, error) {
const op = "audit.NewEntryFormatter"

name = strings.TrimSpace(name)
Expand All @@ -69,18 +66,11 @@ func NewEntryFormatter(name string, config FormatterConfig, salter Salter, logge
return nil, fmt.Errorf("%s: format not valid: %w", op, err)
}

opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("%s: error applying options: %w", op, err)
}

return &EntryFormatter{
config: config,
salter: salter,
logger: logger,
headerFormatter: opts.withHeaderFormatter,
name: name,
prefix: opts.withPrefix,
config: config,
salter: salter,
logger: logger,
name: name,
}, nil
}

Expand Down Expand Up @@ -145,11 +135,14 @@ func (f *EntryFormatter) Process(ctx context.Context, e *eventlogger.Event) (_ *
return nil, fmt.Errorf("%s: unable to copy audit event data: %w", op, err)
}

// Ensure that any headers in the request, are formatted as required, and are
// only present if they have been configured to appear in the audit log.
// e.g. via: /sys/config/auditing/request-headers/:name
if f.headerFormatter != nil && data.Request != nil && data.Request.Headers != nil {
data.Request.Headers, err = f.headerFormatter.ApplyConfig(ctx, data.Request.Headers, f.salter)
// If the request is present in the input data, apply header configuration
// regardless. We shouldn't be in a situation where the header formatter isn't
// present as it's required.
if data.Request != nil {
// Ensure that any headers in the request, are formatted as required, and are
// only present if they have been configured to appear in the audit log.
// e.g. via: /sys/config/auditing/request-headers/:name
data.Request.Headers, err = f.config.headerFormatter.ApplyConfig(ctx, data.Request.Headers, f.salter)
if err != nil {
return nil, fmt.Errorf("%s: unable to transform headers for auditing: %w", op, err)
}
Expand Down Expand Up @@ -198,8 +191,8 @@ func (f *EntryFormatter) Process(ctx context.Context, e *eventlogger.Event) (_ *
// don't support a prefix just sitting there.
// However, this would be a breaking change to how Vault currently works to
// include the prefix as part of the JSON object or XML document.
if f.prefix != "" {
result = append([]byte(f.prefix), result...)
if f.config.Prefix != "" {
result = append([]byte(f.config.Prefix), result...)
}

// Copy some properties from the event (and audit event) and store the
Expand Down Expand Up @@ -577,19 +570,25 @@ func (f *EntryFormatter) FormatResponse(ctx context.Context, in *logical.LogInpu
}

// NewFormatterConfig should be used to create a FormatterConfig.
// Accepted options: WithElision, WithHMACAccessor, WithOmitTime, WithRaw, WithFormat.
func NewFormatterConfig(opt ...Option) (FormatterConfig, error) {
// Accepted options: WithElision, WithFormat, WithHMACAccessor, WithOmitTime, WithPrefix, WithRaw.
func NewFormatterConfig(headerFormatter HeaderFormatter, opt ...Option) (FormatterConfig, error) {
const op = "audit.NewFormatterConfig"

if headerFormatter == nil || reflect.ValueOf(headerFormatter).IsNil() {
return FormatterConfig{}, fmt.Errorf("%s: header formatter is required: %w", op, event.ErrInvalidParameter)
}

opts, err := getOpts(opt...)
if err != nil {
return FormatterConfig{}, fmt.Errorf("%s: error applying options: %w", op, err)
}

return FormatterConfig{
headerFormatter: headerFormatter,
ElideListResponses: opts.withElision,
HMACAccessor: opts.withHMACAccessor,
OmitTime: opts.withOmitTime,
Prefix: opts.withPrefix,
Raw: opts.withRaw,
RequiredFormat: opts.withFormat,
}, nil
Expand Down Expand Up @@ -663,10 +662,8 @@ func doElideListResponseData(data map[string]interface{}) {
// newTemporaryEntryFormatter creates a cloned EntryFormatter instance with a non-persistent Salter.
func newTemporaryEntryFormatter(n *EntryFormatter) *EntryFormatter {
return &EntryFormatter{
salter: &nonPersistentSalt{},
headerFormatter: n.headerFormatter,
config: n.config,
prefix: n.prefix,
salter: &nonPersistentSalt{},
config: n.config,
}
}

Expand Down
105 changes: 87 additions & 18 deletions audit/entry_formatter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,24 @@ const testFormatJSONReqBasicStrFmt = `
}
`

// testHeaderFormatter is a stub to prevent the need to import the vault package
// to bring in vault.AuditedHeadersConfig for testing.
type testHeaderFormatter struct {
shouldReturnEmpty bool
}

// ApplyConfig satisfies the HeaderFormatter interface for testing.
// It will either return the headers it was supplied or empty headers depending
// on how it is configured.
// ignore-nil-nil-function-check.
func (f *testHeaderFormatter) ApplyConfig(_ context.Context, headers map[string][]string, salter Salter) (result map[string][]string, retErr error) {
if f.shouldReturnEmpty {
return make(map[string][]string), nil
}

return headers, nil
}

// testTimeProvider is just a test struct used to imitate an AuditEvent's ability
// to provide a formatted time.
type testTimeProvider struct{}
Expand Down Expand Up @@ -178,9 +196,9 @@ func TestNewEntryFormatter(t *testing.T) {
ss = newStaticSalt(t)
}

cfg, err := NewFormatterConfig(tc.Options...)
cfg, err := NewFormatterConfig(&testHeaderFormatter{}, tc.Options...)
require.NoError(t, err)
f, err := NewEntryFormatter(tc.Name, cfg, ss, tc.Logger, tc.Options...)
f, err := NewEntryFormatter(tc.Name, cfg, ss, tc.Logger /*, tc.Options...*/)
peteski22 marked this conversation as resolved.
Show resolved Hide resolved

switch {
case tc.IsErrorExpected:
Expand All @@ -191,7 +209,7 @@ func TestNewEntryFormatter(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, f)
require.Equal(t, tc.ExpectedFormat, f.config.RequiredFormat)
require.Equal(t, tc.ExpectedPrefix, f.prefix)
require.Equal(t, tc.ExpectedPrefix, f.config.Prefix)
}
})
}
Expand All @@ -202,7 +220,7 @@ func TestEntryFormatter_Reopen(t *testing.T) {
t.Parallel()

ss := newStaticSalt(t)
cfg, err := NewFormatterConfig()
cfg, err := NewFormatterConfig(&testHeaderFormatter{})
require.NoError(t, err)

f, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
Expand All @@ -216,7 +234,7 @@ func TestEntryFormatter_Type(t *testing.T) {
t.Parallel()

ss := newStaticSalt(t)
cfg, err := NewFormatterConfig()
cfg, err := NewFormatterConfig(&testHeaderFormatter{})
require.NoError(t, err)

f, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
Expand Down Expand Up @@ -361,7 +379,7 @@ func TestEntryFormatter_Process(t *testing.T) {
require.NotNil(t, e)

ss := newStaticSalt(t)
cfg, err := NewFormatterConfig(WithFormat(tc.RequiredFormat.String()))
cfg, err := NewFormatterConfig(&testHeaderFormatter{}, WithFormat(tc.RequiredFormat.String()))
require.NoError(t, err)

f, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
Expand Down Expand Up @@ -426,7 +444,7 @@ func BenchmarkAuditFileSink_Process(b *testing.B) {
ctx := namespace.RootContext(context.Background())

// Create the formatter node.
cfg, err := NewFormatterConfig()
cfg, err := NewFormatterConfig(&testHeaderFormatter{})
require.NoError(b, err)
ss := newStaticSalt(b)
formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
Expand Down Expand Up @@ -504,7 +522,7 @@ func TestEntryFormatter_FormatRequest(t *testing.T) {
t.Parallel()

ss := newStaticSalt(t)
cfg, err := NewFormatterConfig(WithOmitTime(tc.ShouldOmitTime))
cfg, err := NewFormatterConfig(&testHeaderFormatter{}, WithOmitTime(tc.ShouldOmitTime))
require.NoError(t, err)
f, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
Expand Down Expand Up @@ -586,7 +604,7 @@ func TestEntryFormatter_FormatResponse(t *testing.T) {
t.Parallel()

ss := newStaticSalt(t)
cfg, err := NewFormatterConfig(WithOmitTime(tc.ShouldOmitTime))
cfg, err := NewFormatterConfig(&testHeaderFormatter{}, WithOmitTime(tc.ShouldOmitTime))
require.NoError(t, err)
f, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
Expand Down Expand Up @@ -702,9 +720,9 @@ func TestEntryFormatter_Process_JSON(t *testing.T) {
}

for name, tc := range cases {
cfg, err := NewFormatterConfig(WithHMACAccessor(false))
cfg, err := NewFormatterConfig(&testHeaderFormatter{}, WithHMACAccessor(false), WithPrefix(tc.Prefix))
require.NoError(t, err)
formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger(), WithPrefix(tc.Prefix))
formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)

in := &logical.LogInput{
Expand Down Expand Up @@ -860,12 +878,14 @@ func TestEntryFormatter_Process_JSONx(t *testing.T) {

for name, tc := range cases {
cfg, err := NewFormatterConfig(
&testHeaderFormatter{},
WithOmitTime(true),
WithHMACAccessor(false),
WithFormat(JSONxFormat.String()),
WithPrefix(tc.Prefix),
)
require.NoError(t, err)
formatter, err := NewEntryFormatter("juan", cfg, tempStaticSalt, hclog.NewNullLogger(), WithPrefix(tc.Prefix))
formatter, err := NewEntryFormatter("juan", cfg, tempStaticSalt, hclog.NewNullLogger())
require.NoError(t, err)
require.NotNil(t, formatter)

Expand Down Expand Up @@ -997,7 +1017,7 @@ func TestEntryFormatter_FormatResponse_ElideListResponses(t *testing.T) {
}

t.Run("Default case", func(t *testing.T) {
config, err := NewFormatterConfig(WithElision(true))
config, err := NewFormatterConfig(&testHeaderFormatter{}, WithElision(true))
require.NoError(t, err)
for name, tc := range tests {
name := name
Expand All @@ -1010,23 +1030,23 @@ func TestEntryFormatter_FormatResponse_ElideListResponses(t *testing.T) {
})

t.Run("When Operation is not list, eliding does not happen", func(t *testing.T) {
config, err := NewFormatterConfig(WithElision(true))
config, err := NewFormatterConfig(&testHeaderFormatter{}, WithElision(true))
require.NoError(t, err)
tc := oneInterestingTestCase
entry := format(t, config, logical.ReadOperation, tc.inputData)
assert.Equal(t, formatter.hashExpectedValueForComparison(tc.inputData), entry.Response.Data)
})

t.Run("When ElideListResponses is false, eliding does not happen", func(t *testing.T) {
config, err := NewFormatterConfig(WithElision(false), WithFormat(JSONFormat.String()))
config, err := NewFormatterConfig(&testHeaderFormatter{}, WithElision(false), WithFormat(JSONFormat.String()))
require.NoError(t, err)
tc := oneInterestingTestCase
entry := format(t, config, logical.ListOperation, tc.inputData)
assert.Equal(t, formatter.hashExpectedValueForComparison(tc.inputData), entry.Response.Data)
})

t.Run("When Raw is true, eliding still happens", func(t *testing.T) {
config, err := NewFormatterConfig(WithElision(true), WithRaw(true), WithFormat(JSONFormat.String()))
config, err := NewFormatterConfig(&testHeaderFormatter{}, WithElision(true), WithRaw(true), WithFormat(JSONFormat.String()))
require.NoError(t, err)
tc := oneInterestingTestCase
entry := format(t, config, logical.ListOperation, tc.inputData)
Expand All @@ -1040,7 +1060,7 @@ func TestEntryFormatter_Process_NoMutation(t *testing.T) {
t.Parallel()

// Create the formatter node.
cfg, err := NewFormatterConfig()
cfg, err := NewFormatterConfig(&testHeaderFormatter{})
require.NoError(t, err)
ss := newStaticSalt(t)
formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
Expand Down Expand Up @@ -1100,7 +1120,7 @@ func TestEntryFormatter_Process_Panic(t *testing.T) {
t.Parallel()

// Create the formatter node.
cfg, err := NewFormatterConfig()
cfg, err := NewFormatterConfig(&testHeaderFormatter{})
require.NoError(t, err)
ss := newStaticSalt(t)
formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
Expand Down Expand Up @@ -1153,6 +1173,55 @@ func TestEntryFormatter_Process_Panic(t *testing.T) {
require.Nil(t, e2)
}

// TestEntryFormatter_NewFormatterConfig_NilHeaderFormatter ensures we cannot
// create a FormatterConfig using NewFormatterConfig if we supply a nil formatter.
func TestEntryFormatter_NewFormatterConfig_NilHeaderFormatter(t *testing.T) {
_, err := NewFormatterConfig(nil)
require.Error(t, err)
}

// TestEntryFormatter_Process_NeverLeaksHeaders ensures that if we never accidentally
// leak headers if applying them means we don't have any. This is more like a sense
// check to ensure the returned event doesn't somehow end up with the headers 'back'.
func TestEntryFormatter_Process_NeverLeaksHeaders(t *testing.T) {
t.Parallel()

// Create the formatter node.
cfg, err := NewFormatterConfig(&testHeaderFormatter{shouldReturnEmpty: true})
require.NoError(t, err)
ss := newStaticSalt(t)
formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
require.NotNil(t, formatter)

// Set up the input and verify we have a single foo:bar header.
var input *logical.LogInput
err = json.Unmarshal([]byte(testFormatJSONReqBasicStrFmt), &input)
require.NoError(t, err)
require.NotNil(t, input)
require.Len(t, input.Request.Headers, 1)
require.Len(t, input.Request.Headers["foo"], 1)
peteski22 marked this conversation as resolved.
Show resolved Hide resolved
require.Equal(t, "bar", input.Request.Headers["foo"][0])

e := fakeEvent(t, RequestType, input)

// Process the node.
ctx := namespace.RootContext(context.Background())
e2, err := formatter.Process(ctx, e)
require.NoError(t, err)
require.NotNil(t, e2)

// Now check we can retrieve the formatted JSON.
jsonFormatted, b2 := e2.Format(JSONFormat.String())
require.True(t, b2)
require.NotNil(t, jsonFormatted)
var input2 *logical.LogInput
err = json.Unmarshal(jsonFormatted, &input2)
require.NoError(t, err)
require.NotNil(t, input2)
require.Len(t, input2.Request.Headers, 0)
}

// hashExpectedValueForComparison replicates enough of the audit HMAC process on a piece of expected data in a test,
// so that we can use assert.Equal to compare the expected and output values.
func (f *EntryFormatter) hashExpectedValueForComparison(input map[string]any) map[string]any {
Expand Down
Loading
Loading