Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

invoices/sqldb: query by ChanID when updating AMP invoice preimage #9022

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,13 @@ linters-settings:
- 'errors.Wrap'

gomoddirectives:
replace-local: true
replace-allow-list:
# See go.mod for the explanation why these are needed.
- github.com/ulikunitz/xz
- github.com/gogo/protobuf
- google.golang.org/protobuf
- github.com/lightningnetwork/lnd/sqldb


linters:
Expand Down
59 changes: 50 additions & 9 deletions channeldb/invoices.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (

// For each key found, we'll look up the actual
// invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(invoiceKey, invoices)
invoice, err := fetchInvoice(
invoiceKey, invoices, nil, false,
)
if err != nil {
return err
}
Expand Down Expand Up @@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (

// An invoice was found, retrieve the remainder of the invoice
// body.
i, err := fetchInvoice(invoiceNum, invoices, setID)
i, err := fetchInvoice(
invoiceNum, invoices, []*invpkg.SetID{setID}, true,
)
if err != nil {
return err
}
Expand Down Expand Up @@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) (
return nil
}

invoice, err := fetchInvoice(v, invoices)
invoice, err := fetchInvoice(v, invoices, nil, false)
if err != nil {
return err
}
Expand Down Expand Up @@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
// characteristics for our query and returns the number of items
// we have added to our set of invoices.
accumulateInvoices := func(_, indexValue []byte) (bool, error) {
invoice, err := fetchInvoice(indexValue, invoices)
invoice, err := fetchInvoice(
indexValue, invoices, nil, false,
)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
if setIDHint != nil {
invSetID = *setIDHint
}
invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID)
invoice, err := fetchInvoice(
invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false,
)
if err != nil {
return err
}
Expand All @@ -676,15 +684,43 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
updatedInvoice, err = invpkg.UpdateInvoice(
payHash, updater.invoice, now, callback, updater,
)
if err != nil {
return err
}

return err
// If this is an AMP update, then limit the returned AMP state
// to only the requested set ID.
if setIDHint != nil {
filterInvoiceAMPState(updatedInvoice, &invSetID)
}

return nil
}, func() {
updatedInvoice = nil
})

return updatedInvoice, err
}

// filterInvoiceAMPState filters the AMP state of the invoice to only include
// state for the specified set IDs.
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
filteredAMPState := make(invpkg.AMPInvoiceState)

for _, setID := range setIDs {
if setID == nil {
return
}

ampState, ok := invoice.AMPState[*setID]
if ok {
filteredAMPState[*setID] = ampState
}
}

invoice.AMPState = filteredAMPState
}

// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC

Expand Down Expand Up @@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
// For each key found, we'll look up the actual
// invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(
invoiceKey[:], invoices, setID,
invoiceKey[:], invoices, []*invpkg.SetID{setID},
true,
)
if err != nil {
return err
Expand Down Expand Up @@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
// specified by the invoice number. If the setID fields are set, then only the
// HTLC information pertaining to those set IDs is returned.
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
setIDs ...*invpkg.SetID) (invpkg.Invoice, error) {
setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {

invoiceBytes := invoices.Get(invoiceNum)
if invoiceBytes == nil {
Expand Down Expand Up @@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
log.Errorf("unable to fetch amp htlcs for inv "+
"%v and setIDs %v: %w", invoiceNum, setIDs, err)
}

if filterAMPState {
filterInvoiceAMPState(&invoice, setIDs...)
}
}

return invoice, nil
Expand Down Expand Up @@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
return nil
}

invoice, err := fetchInvoice(v, invoices)
invoice, err := fetchInvoice(v, invoices, nil, false)
if err != nil {
return err
}
Expand Down
5 changes: 5 additions & 0 deletions docs/release-notes/release-notes-0.18.3.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ blinded path expiry.
cause UpdateAddHTLC message with blinding point fields to not be re-forwarded
correctly on restart.

* [Fixed](https://github.com/lightningnetwork/lnd/pull/9022) a native SQL
invoice issue where AMP subinvoice HTLCs are sometimes updated incorrectly on
settlement.

# New Features
## Functional Enhancements

Expand Down Expand Up @@ -278,6 +282,7 @@ that validate `ChannelAnnouncement` messages.

# Contributors (Alphabetical Order)

* Alex Akselrod
* Andras Banki-Horvath
* bitromortac
* Bufo
Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2
// allows us to specify that as an option.
replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display

// Temporary replace until the next version of sqldb is taged.
replace github.com/lightningnetwork/lnd/sqldb => ./sqldb

// If you change this please also update .github/pull_request_template.md,
// docs/INSTALL.md and GO_IMAGE in lnrpc/gen_protos_docker.sh.
go 1.22.6
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,6 @@ github.com/lightningnetwork/lnd/kvdb v1.4.10 h1:vK89IVv1oVH9ubQWU+EmoCQFeVRaC8kf
github.com/lightningnetwork/lnd/kvdb v1.4.10/go.mod h1:J2diNABOoII9UrMnxXS5w7vZwP7CA1CStrl8MnIrb3A=
github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI=
github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4=
github.com/lightningnetwork/lnd/sqldb v1.0.3 h1:zLfAwOvM+6+3+hahYO9Q3h8pVV0TghAR7iJ5YMLCd3I=
github.com/lightningnetwork/lnd/sqldb v1.0.3/go.mod h1:4cQOkdymlZ1znnjuRNvMoatQGJkRneTj2CoPSPaQhWo=
github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM=
github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA=
github.com/lightningnetwork/lnd/tlv v1.2.6 h1:icvQG2yDr6k3ZuZzfRdG3EJp6pHurcuh3R6dg0gv/Mw=
Expand Down
52 changes: 47 additions & 5 deletions invoices/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strconv"
"time"

"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
Expand Down Expand Up @@ -46,6 +47,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
GetInvoice(ctx context.Context,
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)

GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
error)

GetInvoiceFeatures(ctx context.Context,
invoiceID int64) ([]sqlc.InvoiceFeature, error)

Expand Down Expand Up @@ -343,16 +347,31 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
params.SetID = ref.SetID()[:]
}

rows, err := db.GetInvoice(ctx, params)
var (
rows []sqlc.Invoice
err error
)

// We need to split the query based on how we intend to look up the
// invoice. If only the set ID is given then we want to have an exact
// match on the set ID. If other fields are given, we want to match on
// those fields and the set ID but with a less strict join condition.
if params.Hash == nil && params.PaymentAddr == nil &&
params.SetID != nil {

rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
} else {
rows, err = db.GetInvoice(ctx, params)
}
switch {
case len(rows) == 0:
return nil, ErrInvoiceNotFound

case len(rows) > 1:
// In case the reference is ambiguous, meaning it matches more
// than one invoice, we'll return an error.
return nil, fmt.Errorf("ambiguous invoice ref: %s",
ref.String())
return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
ref.String(), spew.Sdump(rows))

case err != nil:
return nil, fmt.Errorf("unable to fetch invoice: %w", err)
Expand Down Expand Up @@ -906,8 +925,10 @@ func (i *SQLStore) QueryInvoices(ctx context.Context,
}

if q.CreationDateEnd != 0 {
// We need to add 1 to the end date as we're
// checking less than the end date in SQL.
params.CreatedBefore = sqldb.SQLTime(
time.Unix(q.CreationDateEnd, 0).UTC(),
time.Unix(q.CreationDateEnd+1, 0).UTC(),
)
}

Expand Down Expand Up @@ -1116,6 +1137,9 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
SetID: setID[:],
HtlcID: int64(circuitKey.HtlcID),
Preimage: preimage[:],
ChanID: strconv.FormatUint(
guggero marked this conversation as resolved.
Show resolved Hide resolved
circuitKey.ChanID.ToUint64(), 10,
),
},
)
if err != nil {
Expand Down Expand Up @@ -1280,6 +1304,13 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
return err
}

if settleIndex.Valid {
updatedState := s.invoice.AMPState[setID]
updatedState.SettleIndex = uint64(settleIndex.Int64)
updatedState.SettleDate = s.updateTime.UTC()
s.invoice.AMPState[setID] = updatedState
}

return nil
}

Expand All @@ -1298,13 +1329,24 @@ func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
// invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback.
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
_ *SetID, callback InvoiceUpdateCallback) (
setID *SetID, callback InvoiceUpdateCallback) (
*Invoice, error) {

var updatedInvoice *Invoice

txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
if setID != nil {
// Make sure to use the set ID if this is an AMP update.
var setIDBytes [32]byte
copy(setIDBytes[:], setID[:])
ref.setID = &setIDBytes

// If we're updating an AMP invoice, we'll also only
// need to fetch the HTLCs for the given set ID.
ref.refModifier = HtlcSetOnlyModifier
}

invoice, err := i.fetchInvoice(ctx, db, ref)
if err != nil {
return err
Expand Down
14 changes: 9 additions & 5 deletions itest/lnd_amp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
invoiceNtfn := ht.ReceiveInvoiceUpdate(invSubscription)

// The notification should signal that the invoice is now settled, and
// should also include the set ID, and show the proper amount paid.
// should also include the set ID, show the proper amount paid, and have
// the correct settle index and time.
require.True(ht, invoiceNtfn.Settled)
require.Equal(ht, lnrpc.Invoice_SETTLED, invoiceNtfn.State)
require.Equal(ht, paymentAmt, int(invoiceNtfn.AmtPaidSat))
Expand All @@ -270,6 +271,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
firstSetID, _ = hex.DecodeString(setIDStr)
require.Equal(ht, lnrpc.InvoiceHTLCState_SETTLED,
ampState.State)
require.GreaterOrEqual(ht, ampState.SettleTime,
rpcInvoice.CreationDate)
require.Equal(ht, uint64(1), ampState.SettleIndex)
}

// Pay the invoice again, we should get another notification that Dave
Expand Down Expand Up @@ -299,9 +303,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
// return the "projected" sub-invoice for a given setID.
require.Equal(ht, 1, len(invoiceNtfn.Htlcs))

// However the AMP state index should show that there've been two
// repeated payments to this invoice so far.
require.Equal(ht, 2, len(invoiceNtfn.AmpInvoiceState))
// The AMP state should also be restricted to a single entry for the
// "projected" sub-invoice.
require.Equal(ht, 1, len(invoiceNtfn.AmpInvoiceState))

// Now we'll look up the invoice using the new LookupInvoice2 RPC call
// by the set ID of each of the invoices.
Expand Down Expand Up @@ -360,7 +364,7 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
// through.
backlogInv := ht.ReceiveInvoiceUpdate(invSub2)
require.Equal(ht, 1, len(backlogInv.Htlcs))
require.Equal(ht, 2, len(backlogInv.AmpInvoiceState))
require.Equal(ht, 1, len(backlogInv.AmpInvoiceState))
require.True(ht, backlogInv.Settled)
require.Equal(ht, paymentAmt*2, int(backlogInv.AmtPaidSat))
}
Expand Down
6 changes: 4 additions & 2 deletions sqldb/sqlc/amp_invoices.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading