diff --git a/.golangci.yml b/.golangci.yml index 536c1b6552..8114945c6f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -93,11 +93,13 @@ linters-settings: - 'errors.Wrap' gomoddirectives: + replace-local: true replace-allow-list: # See go.mod for the explanation why these are needed. - github.com/ulikunitz/xz - github.com/gogo/protobuf - google.golang.org/protobuf + - github.com/lightningnetwork/lnd/sqldb linters: diff --git a/channeldb/invoices.go b/channeldb/invoices.go index df124b632a..9da504a5d8 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) ( // For each key found, we'll look up the actual // invoice, then accumulate it into our return value. - invoice, err := fetchInvoice(invoiceKey, invoices) + invoice, err := fetchInvoice( + invoiceKey, invoices, nil, false, + ) if err != nil { return err } @@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) ( // An invoice was found, retrieve the remainder of the invoice // body. - i, err := fetchInvoice(invoiceNum, invoices, setID) + i, err := fetchInvoice( + invoiceNum, invoices, []*invpkg.SetID{setID}, true, + ) if err != nil { return err } @@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) ( return nil } - invoice, err := fetchInvoice(v, invoices) + invoice, err := fetchInvoice(v, invoices, nil, false) if err != nil { return err } @@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) ( // characteristics for our query and returns the number of items // we have added to our set of invoices. accumulateInvoices := func(_, indexValue []byte) (bool, error) { - invoice, err := fetchInvoice(indexValue, invoices) + invoice, err := fetchInvoice( + indexValue, invoices, nil, false, + ) if err != nil { return false, err } @@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, if setIDHint != nil { invSetID = *setIDHint } - invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID) + invoice, err := fetchInvoice( + invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false, + ) if err != nil { return err } @@ -676,8 +684,17 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, updatedInvoice, err = invpkg.UpdateInvoice( payHash, updater.invoice, now, callback, updater, ) + if err != nil { + return err + } - return err + // If this is an AMP update, then limit the returned AMP state + // to only the requested set ID. + if setIDHint != nil { + filterInvoiceAMPState(updatedInvoice, &invSetID) + } + + return nil }, func() { updatedInvoice = nil }) @@ -685,6 +702,25 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, return updatedInvoice, err } +// filterInvoiceAMPState filters the AMP state of the invoice to only include +// state for the specified set IDs. +func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) { + filteredAMPState := make(invpkg.AMPInvoiceState) + + for _, setID := range setIDs { + if setID == nil { + return + } + + ampState, ok := invoice.AMPState[*setID] + if ok { + filteredAMPState[*setID] = ampState + } + } + + invoice.AMPState = filteredAMPState +} + // ampHTLCsMap is a map of AMP HTLCs affected by an invoice update. type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC @@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) ( // For each key found, we'll look up the actual // invoice, then accumulate it into our return value. invoice, err := fetchInvoice( - invoiceKey[:], invoices, setID, + invoiceKey[:], invoices, []*invpkg.SetID{setID}, + true, ) if err != nil { return err @@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte, // specified by the invoice number. If the setID fields are set, then only the // HTLC information pertaining to those set IDs is returned. func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, - setIDs ...*invpkg.SetID) (invpkg.Invoice, error) { + setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) { invoiceBytes := invoices.Get(invoiceNum) if invoiceBytes == nil { @@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, log.Errorf("unable to fetch amp htlcs for inv "+ "%v and setIDs %v: %w", invoiceNum, setIDs, err) } + + if filterAMPState { + filterInvoiceAMPState(&invoice, setIDs...) + } } return invoice, nil @@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error { return nil } - invoice, err := fetchInvoice(v, invoices) + invoice, err := fetchInvoice(v, invoices, nil, false) if err != nil { return err } diff --git a/docs/release-notes/release-notes-0.18.3.md b/docs/release-notes/release-notes-0.18.3.md index b91fda5849..80e4417477 100644 --- a/docs/release-notes/release-notes-0.18.3.md +++ b/docs/release-notes/release-notes-0.18.3.md @@ -81,6 +81,10 @@ blinded path expiry. cause UpdateAddHTLC message with blinding point fields to not be re-forwarded correctly on restart. +* [Fixed](https://github.com/lightningnetwork/lnd/pull/9022) a native SQL + invoice issue where AMP subinvoice HTLCs are sometimes updated incorrectly on + settlement. + # New Features ## Functional Enhancements @@ -278,6 +282,7 @@ that validate `ChannelAnnouncement` messages. # Contributors (Alphabetical Order) +* Alex Akselrod * Andras Banki-Horvath * bitromortac * Bufo diff --git a/go.mod b/go.mod index bbe3c88456..3e5b9e8f27 100644 --- a/go.mod +++ b/go.mod @@ -204,6 +204,9 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2 // allows us to specify that as an option. replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display +// Temporary replace until the next version of sqldb is taged. +replace github.com/lightningnetwork/lnd/sqldb => ./sqldb + // If you change this please also update .github/pull_request_template.md, // docs/INSTALL.md and GO_IMAGE in lnrpc/gen_protos_docker.sh. go 1.22.6 diff --git a/go.sum b/go.sum index a51e291b8d..7dd8449cd9 100644 --- a/go.sum +++ b/go.sum @@ -458,8 +458,6 @@ github.com/lightningnetwork/lnd/kvdb v1.4.10 h1:vK89IVv1oVH9ubQWU+EmoCQFeVRaC8kf github.com/lightningnetwork/lnd/kvdb v1.4.10/go.mod h1:J2diNABOoII9UrMnxXS5w7vZwP7CA1CStrl8MnIrb3A= github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI= github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4= -github.com/lightningnetwork/lnd/sqldb v1.0.3 h1:zLfAwOvM+6+3+hahYO9Q3h8pVV0TghAR7iJ5YMLCd3I= -github.com/lightningnetwork/lnd/sqldb v1.0.3/go.mod h1:4cQOkdymlZ1znnjuRNvMoatQGJkRneTj2CoPSPaQhWo= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= github.com/lightningnetwork/lnd/tlv v1.2.6 h1:icvQG2yDr6k3ZuZzfRdG3EJp6pHurcuh3R6dg0gv/Mw= diff --git a/invoices/sql_store.go b/invoices/sql_store.go index 4b488715ba..eb465eabb4 100644 --- a/invoices/sql_store.go +++ b/invoices/sql_store.go @@ -10,6 +10,7 @@ import ( "strconv" "time" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" @@ -46,6 +47,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat GetInvoice(ctx context.Context, arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error) + GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice, + error) + GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]sqlc.InvoiceFeature, error) @@ -343,7 +347,22 @@ func (i *SQLStore) fetchInvoice(ctx context.Context, params.SetID = ref.SetID()[:] } - rows, err := db.GetInvoice(ctx, params) + var ( + rows []sqlc.Invoice + err error + ) + + // We need to split the query based on how we intend to look up the + // invoice. If only the set ID is given then we want to have an exact + // match on the set ID. If other fields are given, we want to match on + // those fields and the set ID but with a less strict join condition. + if params.Hash == nil && params.PaymentAddr == nil && + params.SetID != nil { + + rows, err = db.GetInvoiceBySetID(ctx, params.SetID) + } else { + rows, err = db.GetInvoice(ctx, params) + } switch { case len(rows) == 0: return nil, ErrInvoiceNotFound @@ -351,8 +370,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context, case len(rows) > 1: // In case the reference is ambiguous, meaning it matches more // than one invoice, we'll return an error. - return nil, fmt.Errorf("ambiguous invoice ref: %s", - ref.String()) + return nil, fmt.Errorf("ambiguous invoice ref: %s: %s", + ref.String(), spew.Sdump(rows)) case err != nil: return nil, fmt.Errorf("unable to fetch invoice: %w", err) @@ -906,8 +925,10 @@ func (i *SQLStore) QueryInvoices(ctx context.Context, } if q.CreationDateEnd != 0 { + // We need to add 1 to the end date as we're + // checking less than the end date in SQL. params.CreatedBefore = sqldb.SQLTime( - time.Unix(q.CreationDateEnd, 0).UTC(), + time.Unix(q.CreationDateEnd+1, 0).UTC(), ) } @@ -1116,6 +1137,9 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte, SetID: setID[:], HtlcID: int64(circuitKey.HtlcID), Preimage: preimage[:], + ChanID: strconv.FormatUint( + circuitKey.ChanID.ToUint64(), 10, + ), }, ) if err != nil { @@ -1280,6 +1304,13 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte, return err } + if settleIndex.Valid { + updatedState := s.invoice.AMPState[setID] + updatedState.SettleIndex = uint64(settleIndex.Int64) + updatedState.SettleDate = s.updateTime.UTC() + s.invoice.AMPState[setID] = updatedState + } + return nil } @@ -1298,13 +1329,24 @@ func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error { // invoice and is therefore atomic. The fields to update are controlled by the // supplied callback. func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef, - _ *SetID, callback InvoiceUpdateCallback) ( + setID *SetID, callback InvoiceUpdateCallback) ( *Invoice, error) { var updatedInvoice *Invoice txOpt := SQLInvoiceQueriesTxOptions{readOnly: false} txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error { + if setID != nil { + // Make sure to use the set ID if this is an AMP update. + var setIDBytes [32]byte + copy(setIDBytes[:], setID[:]) + ref.setID = &setIDBytes + + // If we're updating an AMP invoice, we'll also only + // need to fetch the HTLCs for the given set ID. + ref.refModifier = HtlcSetOnlyModifier + } + invoice, err := i.fetchInvoice(ctx, db, ref) if err != nil { return err diff --git a/itest/lnd_amp_test.go b/itest/lnd_amp_test.go index 23bfd8654d..4b4cfb5a29 100644 --- a/itest/lnd_amp_test.go +++ b/itest/lnd_amp_test.go @@ -260,7 +260,8 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) { invoiceNtfn := ht.ReceiveInvoiceUpdate(invSubscription) // The notification should signal that the invoice is now settled, and - // should also include the set ID, and show the proper amount paid. + // should also include the set ID, show the proper amount paid, and have + // the correct settle index and time. require.True(ht, invoiceNtfn.Settled) require.Equal(ht, lnrpc.Invoice_SETTLED, invoiceNtfn.State) require.Equal(ht, paymentAmt, int(invoiceNtfn.AmtPaidSat)) @@ -270,6 +271,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) { firstSetID, _ = hex.DecodeString(setIDStr) require.Equal(ht, lnrpc.InvoiceHTLCState_SETTLED, ampState.State) + require.GreaterOrEqual(ht, ampState.SettleTime, + rpcInvoice.CreationDate) + require.Equal(ht, uint64(1), ampState.SettleIndex) } // Pay the invoice again, we should get another notification that Dave @@ -299,9 +303,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) { // return the "projected" sub-invoice for a given setID. require.Equal(ht, 1, len(invoiceNtfn.Htlcs)) - // However the AMP state index should show that there've been two - // repeated payments to this invoice so far. - require.Equal(ht, 2, len(invoiceNtfn.AmpInvoiceState)) + // The AMP state should also be restricted to a single entry for the + // "projected" sub-invoice. + require.Equal(ht, 1, len(invoiceNtfn.AmpInvoiceState)) // Now we'll look up the invoice using the new LookupInvoice2 RPC call // by the set ID of each of the invoices. @@ -360,7 +364,7 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) { // through. backlogInv := ht.ReceiveInvoiceUpdate(invSub2) require.Equal(ht, 1, len(backlogInv.Htlcs)) - require.Equal(ht, 2, len(backlogInv.AmpInvoiceState)) + require.Equal(ht, 1, len(backlogInv.AmpInvoiceState)) require.True(ht, backlogInv.Settled) require.Equal(ht, paymentAmt*2, int(backlogInv.AmtPaidSat)) } diff --git a/sqldb/sqlc/amp_invoices.sql.go b/sqldb/sqlc/amp_invoices.sql.go index 3fcfe4b27b..e47b1c803d 100644 --- a/sqldb/sqlc/amp_invoices.sql.go +++ b/sqldb/sqlc/amp_invoices.sql.go @@ -268,15 +268,16 @@ func (q *Queries) InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubI const updateAMPSubInvoiceHTLCPreimage = `-- name: UpdateAMPSubInvoiceHTLCPreimage :execresult UPDATE amp_sub_invoice_htlcs AS a -SET preimage = $4 +SET preimage = $5 WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = ( - SELECT id FROM invoice_htlcs AS i WHERE i.htlc_id = $3 + SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4 ) ` type UpdateAMPSubInvoiceHTLCPreimageParams struct { InvoiceID int64 SetID []byte + ChanID string HtlcID int64 Preimage []byte } @@ -285,6 +286,7 @@ func (q *Queries) UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg Updat return q.db.ExecContext(ctx, updateAMPSubInvoiceHTLCPreimage, arg.InvoiceID, arg.SetID, + arg.ChanID, arg.HtlcID, arg.Preimage, ) diff --git a/sqldb/sqlc/invoices.sql.go b/sqldb/sqlc/invoices.sql.go index fde02391c6..9e31380abb 100644 --- a/sqldb/sqlc/invoices.sql.go +++ b/sqldb/sqlc/invoices.sql.go @@ -78,7 +78,7 @@ WHERE ( created_at >= $6 OR $6 IS NULL ) AND ( - created_at <= $7 OR + created_at < $7 OR $7 IS NULL ) AND ( CASE @@ -170,21 +170,22 @@ const getInvoice = `-- name: GetInvoice :many SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at FROM invoices i -LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id +LEFT JOIN amp_sub_invoices a +ON i.id = a.invoice_id +AND ( + a.set_id = $1 OR $1 IS NULL +) WHERE ( - i.id = $1 OR - $1 IS NULL -) AND ( - i.hash = $2 OR + i.id = $2 OR $2 IS NULL ) AND ( - i.preimage = $3 OR + i.hash = $3 OR $3 IS NULL ) AND ( - i.payment_addr = $4 OR + i.preimage = $4 OR $4 IS NULL ) AND ( - a.set_id = $5 OR + i.payment_addr = $5 OR $5 IS NULL ) GROUP BY i.id @@ -192,11 +193,11 @@ LIMIT 2 ` type GetInvoiceParams struct { + SetID []byte AddIndex sql.NullInt64 Hash []byte Preimage []byte PaymentAddr []byte - SetID []byte } // This method may return more than one invoice if filter using multiple fields @@ -204,11 +205,11 @@ type GetInvoiceParams struct { // we bubble up an error in those cases. func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) { rows, err := q.db.QueryContext(ctx, getInvoice, + arg.SetID, arg.AddIndex, arg.Hash, arg.Preimage, arg.PaymentAddr, - arg.SetID, ) if err != nil { return nil, err @@ -250,6 +251,55 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi return items, nil } +const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many +SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at +FROM invoices i +INNER JOIN amp_sub_invoices a +ON i.id = a.invoice_id AND a.set_id = $1 +` + +func (q *Queries) GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) { + rows, err := q.db.QueryContext(ctx, getInvoiceBySetID, setID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Invoice + for rows.Next() { + var i Invoice + if err := rows.Scan( + &i.ID, + &i.Hash, + &i.Preimage, + &i.SettleIndex, + &i.SettledAt, + &i.Memo, + &i.AmountMsat, + &i.CltvDelta, + &i.Expiry, + &i.PaymentAddr, + &i.PaymentRequest, + &i.PaymentRequestHash, + &i.State, + &i.AmountPaidMsat, + &i.IsAmp, + &i.IsHodl, + &i.IsKeysend, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getInvoiceFeatures = `-- name: GetInvoiceFeatures :many SELECT feature, invoice_id FROM invoice_features diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index d55d8090a7..04b61c7007 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -21,6 +21,7 @@ type Querier interface { // from different invoices. It is the caller's responsibility to ensure that // we bubble up an error in those cases. GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) + GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error) GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error) GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error) diff --git a/sqldb/sqlc/queries/amp_invoices.sql b/sqldb/sqlc/queries/amp_invoices.sql index 3b6ee76ac3..1fad75e0da 100644 --- a/sqldb/sqlc/queries/amp_invoices.sql +++ b/sqldb/sqlc/queries/amp_invoices.sql @@ -61,7 +61,7 @@ WHERE ( -- name: UpdateAMPSubInvoiceHTLCPreimage :execresult UPDATE amp_sub_invoice_htlcs AS a -SET preimage = $4 +SET preimage = $5 WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = ( - SELECT id FROM invoice_htlcs AS i WHERE i.htlc_id = $3 + SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4 ); diff --git a/sqldb/sqlc/queries/invoices.sql b/sqldb/sqlc/queries/invoices.sql index 07c5ca418b..2a49553e65 100644 --- a/sqldb/sqlc/queries/invoices.sql +++ b/sqldb/sqlc/queries/invoices.sql @@ -26,7 +26,11 @@ WHERE invoice_id = $1; -- name: GetInvoice :many SELECT i.* FROM invoices i -LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id +LEFT JOIN amp_sub_invoices a +ON i.id = a.invoice_id +AND ( + a.set_id = sqlc.narg('set_id') OR sqlc.narg('set_id') IS NULL +) WHERE ( i.id = sqlc.narg('add_index') OR sqlc.narg('add_index') IS NULL @@ -39,13 +43,16 @@ WHERE ( ) AND ( i.payment_addr = sqlc.narg('payment_addr') OR sqlc.narg('payment_addr') IS NULL -) AND ( - a.set_id = sqlc.narg('set_id') OR - sqlc.narg('set_id') IS NULL ) GROUP BY i.id LIMIT 2; +-- name: GetInvoiceBySetID :many +SELECT i.* +FROM invoices i +INNER JOIN amp_sub_invoices a +ON i.id = a.invoice_id AND a.set_id = $1; + -- name: FilterInvoices :many SELECT invoices.* @@ -69,7 +76,7 @@ WHERE ( created_at >= sqlc.narg('created_after') OR sqlc.narg('created_after') IS NULL ) AND ( - created_at <= sqlc.narg('created_before') OR + created_at < sqlc.narg('created_before') OR sqlc.narg('created_before') IS NULL ) AND ( CASE