diff --git a/app/initstores.go b/app/initstores.go index 87386e5959..6f61312142 100644 --- a/app/initstores.go +++ b/app/initstores.go @@ -220,10 +220,7 @@ func (app *App) initStores(ctx context.Context) error { } if app.IntegrationKeyStore == nil { - app.IntegrationKeyStore, err = integrationkey.NewStore(ctx, app.db) - } - if err != nil { - return errors.Wrap(err, "init integration key store") + app.IntegrationKeyStore = integrationkey.NewStore(ctx, app.db) } if app.ScheduleRuleStore == nil { diff --git a/devtools/pgdump-lite/pgd/db.go b/devtools/pgdump-lite/pgd/db.go index ab6fe88994..f966bc9545 100644 --- a/devtools/pgdump-lite/pgd/db.go +++ b/devtools/pgdump-lite/pgd/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.21.0 package pgd diff --git a/devtools/pgdump-lite/pgd/models.go b/devtools/pgdump-lite/pgd/models.go index 793235f578..0c17183996 100644 --- a/devtools/pgdump-lite/pgd/models.go +++ b/devtools/pgdump-lite/pgd/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.21.0 package pgd diff --git a/devtools/pgdump-lite/pgd/queries.sql.go b/devtools/pgdump-lite/pgd/queries.sql.go index 14064dc4a0..207811ca27 100644 --- a/devtools/pgdump-lite/pgd/queries.sql.go +++ b/devtools/pgdump-lite/pgd/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.20.0 +// sqlc v1.21.0 // source: queries.sql package pgd diff --git a/gadb/queries.sql.go b/gadb/queries.sql.go index 41b3a9c660..54c20df83d 100644 --- a/gadb/queries.sql.go +++ b/gadb/queries.sql.go @@ -718,6 +718,138 @@ func (q *Queries) FindOneCalSubForUpdate(ctx context.Context, id uuid.UUID) (Fin return i, err } +const intKeyCreate = `-- name: IntKeyCreate :exec +INSERT INTO integration_keys(id, name, type, service_id) + VALUES ($1, $2, $3, $4) +` + +type IntKeyCreateParams struct { + ID uuid.UUID + Name string + Type EnumIntegrationKeysType + ServiceID uuid.UUID +} + +func (q *Queries) IntKeyCreate(ctx context.Context, arg IntKeyCreateParams) error { + _, err := q.db.ExecContext(ctx, intKeyCreate, + arg.ID, + arg.Name, + arg.Type, + arg.ServiceID, + ) + return err +} + +const intKeyDelete = `-- name: IntKeyDelete :exec +DELETE FROM integration_keys +WHERE id = ANY ($1::uuid[]) +` + +func (q *Queries) IntKeyDelete(ctx context.Context, ids []uuid.UUID) error { + _, err := q.db.ExecContext(ctx, intKeyDelete, pq.Array(ids)) + return err +} + +const intKeyFindByService = `-- name: IntKeyFindByService :many +SELECT + id, + name, + type, + service_id +FROM + integration_keys +WHERE + service_id = $1 +` + +type IntKeyFindByServiceRow struct { + ID uuid.UUID + Name string + Type EnumIntegrationKeysType + ServiceID uuid.UUID +} + +func (q *Queries) IntKeyFindByService(ctx context.Context, serviceID uuid.UUID) ([]IntKeyFindByServiceRow, error) { + rows, err := q.db.QueryContext(ctx, intKeyFindByService, serviceID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []IntKeyFindByServiceRow + for rows.Next() { + var i IntKeyFindByServiceRow + if err := rows.Scan( + &i.ID, + &i.Name, + &i.Type, + &i.ServiceID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const intKeyFindOne = `-- name: IntKeyFindOne :one +SELECT + id, + name, + type, + service_id +FROM + integration_keys +WHERE + id = $1 +` + +type IntKeyFindOneRow struct { + ID uuid.UUID + Name string + Type EnumIntegrationKeysType + ServiceID uuid.UUID +} + +func (q *Queries) IntKeyFindOne(ctx context.Context, id uuid.UUID) (IntKeyFindOneRow, error) { + row := q.db.QueryRowContext(ctx, intKeyFindOne, id) + var i IntKeyFindOneRow + err := row.Scan( + &i.ID, + &i.Name, + &i.Type, + &i.ServiceID, + ) + return i, err +} + +const intKeyGetServiceID = `-- name: IntKeyGetServiceID :one +SELECT + service_id +FROM + integration_keys +WHERE + id = $1 + AND type = $2 +` + +type IntKeyGetServiceIDParams struct { + ID uuid.UUID + Type EnumIntegrationKeysType +} + +func (q *Queries) IntKeyGetServiceID(ctx context.Context, arg IntKeyGetServiceIDParams) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, intKeyGetServiceID, arg.ID, arg.Type) + var service_id uuid.UUID + err := row.Scan(&service_id) + return service_id, err +} + const lockOneAlertService = `-- name: LockOneAlertService :one SELECT maintenance_expires_at NOTNULL::bool AS is_maint_mode, diff --git a/graphql2/graphqlapp/integrationkey.go b/graphql2/graphqlapp/integrationkey.go index 0ec116ed26..e668fb3e70 100644 --- a/graphql2/graphqlapp/integrationkey.go +++ b/graphql2/graphqlapp/integrationkey.go @@ -29,7 +29,7 @@ func (m *Mutation) CreateIntegrationKey(ctx context.Context, input graphql2.Crea Name: input.Name, Type: integrationkey.Type(input.Type), } - key, err = m.IntKeyStore.CreateKeyTx(ctx, tx, key) + key, err = m.IntKeyStore.Create(ctx, tx, key) return err }) return key, err diff --git a/graphql2/graphqlapp/mutation.go b/graphql2/graphqlapp/mutation.go index f1ba1cccdb..e03d6cbc5d 100644 --- a/graphql2/graphqlapp/mutation.go +++ b/graphql2/graphqlapp/mutation.go @@ -261,7 +261,7 @@ func (a *Mutation) tryDeleteAll(ctx context.Context, input []assignment.RawTarge case assignment.TargetTypeEscalationPolicy: err = errors.Wrap(a.PolicyStore.DeleteManyPoliciesTx(ctx, tx, ids), "delete escalation policies") case assignment.TargetTypeIntegrationKey: - err = errors.Wrap(a.IntKeyStore.DeleteManyTx(ctx, tx, ids), "delete integration keys") + err = errors.Wrap(a.IntKeyStore.DeleteMany(ctx, tx, ids), "delete integration keys") case assignment.TargetTypeSchedule: err = errors.Wrap(a.ScheduleStore.DeleteManyTx(ctx, tx, ids), "delete schedules") case assignment.TargetTypeCalendarSubscription: diff --git a/integrationkey/queries.sql b/integrationkey/queries.sql new file mode 100644 index 0000000000..046443d702 --- /dev/null +++ b/integrationkey/queries.sql @@ -0,0 +1,39 @@ +-- name: IntKeyGetServiceID :one +SELECT + service_id +FROM + integration_keys +WHERE + id = $1 + AND type = $2; + +-- name: IntKeyCreate :exec +INSERT INTO integration_keys(id, name, type, service_id) + VALUES ($1, $2, $3, $4); + +-- name: IntKeyFindOne :one +SELECT + id, + name, + type, + service_id +FROM + integration_keys +WHERE + id = $1; + +-- name: IntKeyFindByService :many +SELECT + id, + name, + type, + service_id +FROM + integration_keys +WHERE + service_id = $1; + +-- name: IntKeyDelete :exec +DELETE FROM integration_keys +WHERE id = ANY (@ids::uuid[]); + diff --git a/integrationkey/store.go b/integrationkey/store.go index bb94efd8a9..30bea03e8e 100644 --- a/integrationkey/store.go +++ b/integrationkey/store.go @@ -5,9 +5,8 @@ import ( "database/sql" "github.com/target/goalert/auth/authtoken" + "github.com/target/goalert/gadb" "github.com/target/goalert/permission" - "github.com/target/goalert/util" - "github.com/target/goalert/util/sqlutil" "github.com/target/goalert/validation/validate" "github.com/google/uuid" @@ -16,26 +15,10 @@ import ( type Store struct { db *sql.DB - - getServiceID *sql.Stmt - create *sql.Stmt - findOne *sql.Stmt - findAllByService *sql.Stmt - delete *sql.Stmt } -func NewStore(ctx context.Context, db *sql.DB) (*Store, error) { - p := &util.Prepare{DB: db, Ctx: ctx} - - return &Store{ - db: db, - - getServiceID: p.P("SELECT service_id FROM integration_keys WHERE id = $1 AND type = $2"), - create: p.P("INSERT INTO integration_keys (id, name, type, service_id) VALUES ($1, $2, $3, $4)"), - findOne: p.P("SELECT id, name, type, service_id FROM integration_keys WHERE id = $1"), - findAllByService: p.P("SELECT id, name, type, service_id FROM integration_keys WHERE service_id = $1"), - delete: p.P("DELETE FROM integration_keys WHERE id = any($1)"), - }, p.Err +func NewStore(ctx context.Context, db *sql.DB) *Store { + return &Store{db: db} } func (s *Store) Authorize(ctx context.Context, tok authtoken.Token, t Type) (context.Context, error) { @@ -58,8 +41,9 @@ func (s *Store) Authorize(ctx context.Context, tok authtoken.Token, t Type) (con } func (s *Store) GetServiceID(ctx context.Context, id string, t Type) (string, error) { - err := validate.Many( - validate.UUID("IntegrationKeyID", id), + keyUUID, err := validate.ParseUUID("IntegrationKeyID", id) + err = validate.Many( + err, validate.OneOf("IntegrationType", t, TypeGrafana, TypeSite24x7, TypePrometheusAlertmanager, TypeGeneric, TypeEmail), ) if err != nil { @@ -70,10 +54,11 @@ func (s *Store) GetServiceID(ctx context.Context, id string, t Type) (string, er return "", err } - row := s.getServiceID.QueryRowContext(ctx, id, t) + serviceID, err := gadb.New(s.db).IntKeyGetServiceID(ctx, gadb.IntKeyGetServiceIDParams{ + ID: keyUUID, + Type: gadb.EnumIntegrationKeysType(t), + }) - var serviceID string - err = row.Scan(&serviceID) if errors.Is(err, sql.ErrNoRows) { return "", err } @@ -81,14 +66,10 @@ func (s *Store) GetServiceID(ctx context.Context, id string, t Type) (string, er return "", errors.WithMessage(err, "lookup failure") } - return serviceID, nil + return serviceID.String(), nil } -func (s *Store) Create(ctx context.Context, i *IntegrationKey) (*IntegrationKey, error) { - return s.CreateKeyTx(ctx, nil, i) -} - -func (s *Store) CreateKeyTx(ctx context.Context, tx *sql.Tx, i *IntegrationKey) (*IntegrationKey, error) { +func (s *Store) Create(ctx context.Context, dbtx gadb.DBTX, i *IntegrationKey) (*IntegrationKey, error) { err := permission.LimitCheckAny(ctx, permission.Admin, permission.User) if err != nil { return nil, err @@ -99,47 +80,46 @@ func (s *Store) CreateKeyTx(ctx context.Context, tx *sql.Tx, i *IntegrationKey) return nil, err } - stmt := s.create - if tx != nil { - stmt = tx.Stmt(stmt) + serviceUUID, err := uuid.Parse(n.ServiceID) + if err != nil { + return nil, err } - n.ID = uuid.New().String() - _, err = stmt.ExecContext(ctx, n.ID, n.Name, n.Type, n.ServiceID) + keyUUID := uuid.New() + n.ID = keyUUID.String() + err = gadb.New(dbtx).IntKeyCreate(ctx, gadb.IntKeyCreateParams{ + ID: keyUUID, + Name: n.Name, + Type: gadb.EnumIntegrationKeysType(n.Type), + ServiceID: serviceUUID, + }) if err != nil { return nil, err } return n, nil } -func (s *Store) Delete(ctx context.Context, id string) error { - return s.DeleteTx(ctx, nil, id) -} - -func (s *Store) DeleteTx(ctx context.Context, tx *sql.Tx, id string) error { - return s.DeleteManyTx(ctx, tx, []string{id}) +func (s *Store) Delete(ctx context.Context, dbtx gadb.DBTX, id string) error { + return s.DeleteMany(ctx, dbtx, []string{id}) } -func (s *Store) DeleteManyTx(ctx context.Context, tx *sql.Tx, ids []string) error { +func (s *Store) DeleteMany(ctx context.Context, dbtx gadb.DBTX, ids []string) error { err := permission.LimitCheckAny(ctx, permission.Admin, permission.User) if err != nil { return err } - err = validate.ManyUUID("IntegrationKeyID", ids, 50) + + uuids, err := validate.ParseManyUUID("IntegrationKeyID", ids, 50) if err != nil { return err } - stmt := s.delete - if tx != nil { - stmt = tx.Stmt(stmt) - } - _, err = stmt.ExecContext(ctx, sqlutil.UUIDArray(ids)) + err = gadb.New(dbtx).IntKeyDelete(ctx, uuids) return err } func (s *Store) FindOne(ctx context.Context, id string) (*IntegrationKey, error) { - err := validate.UUID("IntegrationKeyID", id) + keyUUID, err := validate.ParseUUID("IntegrationKeyID", id) if err != nil { return nil, err } @@ -149,9 +129,7 @@ func (s *Store) FindOne(ctx context.Context, id string) (*IntegrationKey, error) return nil, err } - row := s.findOne.QueryRowContext(ctx, id) - var i IntegrationKey - err = scanFrom(&i, row.Scan) + row, err := gadb.New(s.db).IntKeyFindOne(ctx, keyUUID) if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -159,11 +137,16 @@ func (s *Store) FindOne(ctx context.Context, id string) (*IntegrationKey, error) return nil, err } - return &i, nil + return &IntegrationKey{ + ID: row.ID.String(), + Name: row.Name, + Type: Type(row.Type), + ServiceID: row.ServiceID.String(), + }, nil } func (s *Store) FindAllByService(ctx context.Context, serviceID string) ([]IntegrationKey, error) { - err := validate.UUID("ServiceID", serviceID) + serviceUUID, err := validate.ParseUUID("ServiceID", serviceID) if err != nil { return nil, err } @@ -173,26 +156,18 @@ func (s *Store) FindAllByService(ctx context.Context, serviceID string) ([]Integ return nil, err } - rows, err := s.findAllByService.QueryContext(ctx, serviceID) + rows, err := gadb.New(s.db).IntKeyFindByService(ctx, serviceUUID) if err != nil { return nil, err } - defer rows.Close() - return scanAllFrom(rows) -} - -func scanFrom(i *IntegrationKey, f func(args ...interface{}) error) error { - return f(&i.ID, &i.Name, &i.Type, &i.ServiceID) -} - -func scanAllFrom(rows *sql.Rows) (integrationKeys []IntegrationKey, err error) { - var i IntegrationKey - for rows.Next() { - err = scanFrom(&i, rows.Scan) - if err != nil { - return nil, err + keys := make([]IntegrationKey, len(rows)) + for i, row := range rows { + keys[i] = IntegrationKey{ + ID: row.ID.String(), + Name: row.Name, + Type: Type(row.Type), + ServiceID: row.ServiceID.String(), } - integrationKeys = append(integrationKeys, i) } - return integrationKeys, nil + return keys, nil } diff --git a/sqlc.yaml b/sqlc.yaml index f2de6cfd04..086167adfb 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -24,6 +24,7 @@ sql: - engine/statusmgr/queries.sql - auth/authlink/queries.sql - alert/alertlog/queries.sql + - integrationkey/queries.sql - apikey/queries.sql engine: postgresql gen: diff --git a/validation/validate/uuid.go b/validation/validate/uuid.go index cc2efd163c..d295c30907 100644 --- a/validation/validate/uuid.go +++ b/validation/validate/uuid.go @@ -11,17 +11,24 @@ import ( // UUID will validate a UUID, returning a FieldError // if invalid. func UUID(fname, u string) error { + _, err := ParseUUID(fname, u) + return err +} + +// ParseUUID will validate a UUID, returning a FieldError +// if invalid and the parsed UUID otherwise. +func ParseUUID(fname, u string) (uuid.UUID, error) { if len(u) != 36 { // Format check only required to ensure string IDs are valid when being passed to DB. // // We can remove this check once we switch to uuid.UUID in structs everywhere. - return validation.NewFieldError(fname, "must be valid UUID: format must be xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx") + return uuid.UUID{}, validation.NewFieldError(fname, "must be valid UUID: format must be xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx") } - _, err := uuid.Parse(u) + parsed, err := uuid.Parse(u) if err != nil { - return validation.NewFieldError(fname, "must be a valid UUID: "+err.Error()) + return uuid.UUID{}, validation.NewFieldError(fname, "must be a valid UUID: "+err.Error()) } - return nil + return parsed, nil } // NullUUID will validate a UUID, unless Null. It returns a FieldError @@ -36,16 +43,28 @@ func NullUUID(fname string, u sql.NullString) error { // ManyUUID will validate a slice of strings, checking each // with the UUID validator. func ManyUUID(fname string, ids []string, max int) error { + _, err := ParseManyUUID(fname, ids, max) + return err +} + +// ParseManyUUID will validate a slice of strings, checking each +// with the UUID validator, and returning a slice of the parsed UUIDs +// if successful. +func ParseManyUUID(fname string, ids []string, max int) ([]uuid.UUID, error) { if max != -1 && len(ids) > max { - return validation.NewFieldError(fname, "must not have more than "+strconv.Itoa(max)) + return nil, validation.NewFieldError(fname, "must not have more than "+strconv.Itoa(max)) } + uuids := make([]uuid.UUID, len(ids)) errs := make([]error, 0, len(ids)) var err error for i, id := range ids { - err = UUID(fname+"["+strconv.Itoa(i)+"]", id) + uuids[i], err = ParseUUID(fname+"["+strconv.Itoa(i)+"]", id) if err != nil { errs = append(errs, err) } } - return Many(errs...) + if len(errs) > 0 { + return nil, Many(errs...) + } + return uuids, nil } diff --git a/validation/validate/uuid_test.go b/validation/validate/uuid_test.go new file mode 100644 index 0000000000..b706f1fb6f --- /dev/null +++ b/validation/validate/uuid_test.go @@ -0,0 +1,84 @@ +package validate + +import ( + "strings" + "testing" +) + +func TestParseUUID(t *testing.T) { + test := func(valid bool, id string) { + var title string + if valid { + title = "Valid" + } else { + title = "Invalid" + } + t.Run(title, func(t *testing.T) { + parsed, err := ParseUUID("UUID", id) + if err == nil && !valid { + t.Errorf("ParseUUID(%s) err = nil; want error", id) + } else if err != nil && valid { + t.Errorf("ParseUUID(%s) err = %v; want nil", id, err) + } else if valid && !strings.EqualFold(parsed.String(), id) { + t.Errorf("ParseUUID(%s) parsed = %s; want %s", id, parsed.String(), id) + } + }) + } + + invalid := []string{ + "", "12345", "b8b3ee1d-5ff8-4751-9$08-cb3e8b214790", "b8b3ee1d5ff847519208cb3e8b214790", + } + valid := []string{ + "b8b3ee1d-5ff8-4751-9208-cb3e8b214790", "00000000-0000-0000-0000-000000000000", "B8B3EE1D-5FF8-4751-9208-cb3e8b214790", + } + for _, n := range invalid { + test(false, n) + } + for _, n := range valid { + test(true, n) + } +} + +func TestParseManyUUID(t *testing.T) { + maxLength := 3 + test := func(valid bool, ids []string) { + var title string + if valid { + title = "Valid" + } else { + title = "Invalid" + } + t.Run(title, func(t *testing.T) { + parsed, err := ParseManyUUID("UUID", ids, maxLength) + if err == nil && !valid { + t.Errorf("ParseManyUUID(%v) err = nil; want error", ids) + } else if err != nil && valid { + t.Errorf("ParseUUID(%v) err = %v; want nil", ids, err) + } else if valid { + for i, p := range parsed { + if !strings.EqualFold(p.String(), ids[i]) { + t.Errorf("ParseUUID(%v) parsed[%d] = %s; want %s", ids, i, p.String(), ids[i]) + } + } + } + }) + } + + invalid := [][]string{ + {"b8b3ee1d-5ff8-4751-9208-cb3e8b214790", "", "00000000-0000-0000-0000-000000000000"}, + {"b8b3ee1d-5ff8-4751-9208-cb3e8b214790", "12345", "00000000-0000-0000-0000-000000000000"}, + {"b8b3ee1d-5ff8-4751-9208-cb3e8b214790", "00000000-0000-0000-0000-000000000000", "00000000-0000-0000-0000-000000000001", "00000000-0000-0000-0000-000000000002"}, + } + valid := [][]string{ + {"B8B3EE1D-5FF8-4751-9208-cb3e8b214790"}, + {"b8b3ee1d-5ff8-4751-9208-cb3e8b214790", "00000000-0000-0000-0000-000000000000", "00000000-0000-0000-0000-000000000001"}, + {"b8b3ee1d-5ff8-4751-9208-cb3e8b214790"}, + {}, + } + for _, n := range invalid { + test(false, n) + } + for _, n := range valid { + test(true, n) + } +}