diff --git a/channeldb/channel.go b/channeldb/channel.go index 046ef8806d..ad02084671 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -25,6 +25,7 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/tlv" @@ -1690,11 +1691,11 @@ func (c *OpenChannel) isBorked(chanBucket kvdb.RBucket) (bool, error) { // republish this tx at startup to ensure propagation, and we should still // handle the case where a different tx actually hits the chain. func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx, - locallyInitiated bool) error { + closer lntypes.ChannelParty) error { return c.markBroadcasted( ChanStatusCommitBroadcasted, forceCloseTxKey, closeTx, - locallyInitiated, + closer, ) } @@ -1706,11 +1707,11 @@ func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx, // ensure propagation, and we should still handle the case where a different tx // actually hits the chain. func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx, - locallyInitiated bool) error { + closer lntypes.ChannelParty) error { return c.markBroadcasted( ChanStatusCoopBroadcasted, coopCloseTxKey, closeTx, - locallyInitiated, + closer, ) } @@ -1719,7 +1720,7 @@ func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx, // which should specify either a coop or force close. It adds a status which // indicates the party that initiated the channel close. func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte, - closeTx *wire.MsgTx, locallyInitiated bool) error { + closeTx *wire.MsgTx, closer lntypes.ChannelParty) error { c.Lock() defer c.Unlock() @@ -1741,7 +1742,7 @@ func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte, // Add the initiator status to the status provided. These statuses are // set in addition to the broadcast status so that we do not need to // migrate the original logic which does not store initiator. - if locallyInitiated { + if closer.IsLocal() { status |= ChanStatusLocalCloseInitiator } else { status |= ChanStatusRemoteCloseInitiator @@ -4486,6 +4487,15 @@ func NewShutdownInfo(deliveryScript lnwire.DeliveryAddress, } } +// Closer identifies the ChannelParty that initiated the coop-closure process. +func (s ShutdownInfo) Closer() lntypes.ChannelParty { + if s.LocalInitiator.Val { + return lntypes.Local + } + + return lntypes.Remote +} + // encode serialises the ShutdownInfo to the given io.Writer. func (s *ShutdownInfo) encode(w io.Writer) error { records := []tlv.Record{ diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 981ddf688b..e630b1c48c 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -21,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lntest/channels" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/tlv" @@ -1084,13 +1085,17 @@ func TestFetchWaitingCloseChannels(t *testing.T) { }, ) - if err := channel.MarkCommitmentBroadcasted(closeTx, true); err != nil { + if err := channel.MarkCommitmentBroadcasted( + closeTx, lntypes.Local, + ); err != nil { t.Fatalf("unable to mark commitment broadcast: %v", err) } // Now try to marking a coop close with a nil tx. This should // succeed, but it shouldn't exit when queried. - if err = channel.MarkCoopBroadcasted(nil, true); err != nil { + if err = channel.MarkCoopBroadcasted( + nil, lntypes.Local, + ); err != nil { t.Fatalf("unable to mark nil coop broadcast: %v", err) } _, err := channel.BroadcastedCooperative() @@ -1102,7 +1107,9 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // it as coop closed. Later we will test that distinct // transactions are returned for both coop and force closes. closeTx.TxIn[0].PreviousOutPoint.Index ^= 1 - if err := channel.MarkCoopBroadcasted(closeTx, true); err != nil { + if err := channel.MarkCoopBroadcasted( + closeTx, lntypes.Local, + ); err != nil { t.Fatalf("unable to mark coop broadcast: %v", err) } } @@ -1324,7 +1331,7 @@ func TestCloseInitiator(t *testing.T) { // by the local party. updateChannel: func(c *OpenChannel) error { return c.MarkCoopBroadcasted( - &wire.MsgTx{}, true, + &wire.MsgTx{}, lntypes.Local, ) }, expectedStatuses: []ChannelStatus{ @@ -1338,7 +1345,7 @@ func TestCloseInitiator(t *testing.T) { // by the remote party. updateChannel: func(c *OpenChannel) error { return c.MarkCoopBroadcasted( - &wire.MsgTx{}, false, + &wire.MsgTx{}, lntypes.Remote, ) }, expectedStatuses: []ChannelStatus{ @@ -1352,7 +1359,7 @@ func TestCloseInitiator(t *testing.T) { // local initiator. updateChannel: func(c *OpenChannel) error { return c.MarkCommitmentBroadcasted( - &wire.MsgTx{}, true, + &wire.MsgTx{}, lntypes.Local, ) }, expectedStatuses: []ChannelStatus{ diff --git a/channeldb/db_test.go b/channeldb/db_test.go index a954f28284..025bf12616 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" "github.com/stretchr/testify/require" @@ -606,7 +607,9 @@ func TestFetchChannels(t *testing.T) { channelIDOption(pendingWaitingChan), ) - err = pendingClosing.MarkCoopBroadcasted(nil, true) + err = pendingClosing.MarkCoopBroadcasted( + nil, lntypes.Local, + ) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -626,7 +629,9 @@ func TestFetchChannels(t *testing.T) { channelIDOption(openWaitingChan), openChannelOption(), ) - err = openClosing.MarkCoopBroadcasted(nil, true) + err = openClosing.MarkCoopBroadcasted( + nil, lntypes.Local, + ) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index 36f6dad18b..abaca5c2ba 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -11,6 +11,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntest/mock" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" ) @@ -61,12 +62,14 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { for i := 0; i < numChans/2; i++ { closeTx := channels[i].FundingTxn.Copy() closeTx.TxIn[0].PreviousOutPoint = channels[i].FundingOutpoint - err := channels[i].MarkCommitmentBroadcasted(closeTx, true) + err := channels[i].MarkCommitmentBroadcasted( + closeTx, lntypes.Local, + ) if err != nil { t.Fatal(err) } - err = channels[i].MarkCoopBroadcasted(closeTx, true) + err = channels[i].MarkCoopBroadcasted(closeTx, lntypes.Local) if err != nil { t.Fatal(err) } diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 962a239e78..3cbc7422de 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -20,6 +20,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -418,7 +419,7 @@ func (c *chainWatcher) handleUnknownLocalState( // and remote keys for this state. We use our point as only we can // revoke our own commitment. commitKeyRing := lnwallet.DeriveCommitmentKeys( - commitPoint, true, c.cfg.chanState.ChanType, + commitPoint, lntypes.Local, c.cfg.chanState.ChanType, &c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg, ) @@ -891,7 +892,7 @@ func (c *chainWatcher) handlePossibleBreach(commitSpend *chainntnfs.SpendDetail, // Create an AnchorResolution for the breached state. anchorRes, err := lnwallet.NewAnchorResolution( c.cfg.chanState, commitSpend.SpendingTx, retribution.KeyRing, - false, + lntypes.Remote, ) if err != nil { return false, fmt.Errorf("unable to create anchor "+ diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 8add61ce6c..cb5cee8720 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -129,7 +129,7 @@ type ChannelArbitratorConfig struct { // MarkCommitmentBroadcasted should mark the channel as the commitment // being broadcast, and we are waiting for the commitment to confirm. - MarkCommitmentBroadcasted func(*wire.MsgTx, bool) error + MarkCommitmentBroadcasted func(*wire.MsgTx, lntypes.ChannelParty) error // MarkChannelClosed marks the channel closed in the database, with the // passed close summary. After this method successfully returns we can @@ -1084,7 +1084,7 @@ func (c *ChannelArbitrator) stateStep( // database, such that we can re-publish later in case it // didn't propagate. We initiated the force close, so we // mark broadcast with local initiator set to true. - err = c.cfg.MarkCommitmentBroadcasted(closeTx, true) + err = c.cfg.MarkCommitmentBroadcasted(closeTx, lntypes.Local) if err != nil { log.Errorf("ChannelArbitrator(%v): unable to "+ "mark commitment broadcasted: %v", diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 43238494ef..916cd5f580 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -416,7 +416,9 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, resolvedChan <- struct{}{} return nil }, - MarkCommitmentBroadcasted: func(_ *wire.MsgTx, _ bool) error { + MarkCommitmentBroadcasted: func(_ *wire.MsgTx, + _ lntypes.ChannelParty) error { + return nil }, MarkChannelClosed: func(*channeldb.ChannelCloseSummary, diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index a55cd5d0b2..1311373a17 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -63,7 +63,7 @@ type dustHandler interface { // getDustSum returns the dust sum on either the local or remote // commitment. An optional fee parameter can be passed in which is used // to calculate the dust sum. - getDustSum(remote bool, + getDustSum(whoseCommit lntypes.ChannelParty, fee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi // getFeeRate returns the current channel feerate. diff --git a/htlcswitch/link.go b/htlcswitch/link.go index eee18ff598..eaeaf2e87c 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2727,10 +2727,10 @@ func (l *channelLink) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error { // method. // // NOTE: Part of the dustHandler interface. -func (l *channelLink) getDustSum(remote bool, +func (l *channelLink) getDustSum(whoseCommit lntypes.ChannelParty, dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { - return l.channel.GetDustSum(remote, dryRunFee) + return l.channel.GetDustSum(whoseCommit, dryRunFee) } // getFeeRate is a wrapper method that retrieves the underlying channel's @@ -2784,8 +2784,8 @@ func (l *channelLink) exceedsFeeExposureLimit( // Get the sum of dust for both the local and remote commitments using // this "dry-run" fee. - localDustSum := l.getDustSum(false, dryRunFee) - remoteDustSum := l.getDustSum(true, dryRunFee) + localDustSum := l.getDustSum(lntypes.Local, dryRunFee) + remoteDustSum := l.getDustSum(lntypes.Remote, dryRunFee) // Calculate the local and remote commitment fees using this dry-run // fee. @@ -2826,12 +2826,16 @@ func (l *channelLink) isOverexposedWithHtlc(htlc *lnwire.UpdateAddHTLC, amount := htlc.Amount.ToSatoshis() // See if this HTLC is dust on both the local and remote commitments. - isLocalDust := dustClosure(feeRate, incoming, true, amount) - isRemoteDust := dustClosure(feeRate, incoming, false, amount) + isLocalDust := dustClosure(feeRate, incoming, lntypes.Local, amount) + isRemoteDust := dustClosure(feeRate, incoming, lntypes.Remote, amount) // Calculate the dust sum for the local and remote commitments. - localDustSum := l.getDustSum(false, fn.None[chainfee.SatPerKWeight]()) - remoteDustSum := l.getDustSum(true, fn.None[chainfee.SatPerKWeight]()) + localDustSum := l.getDustSum( + lntypes.Local, fn.None[chainfee.SatPerKWeight](), + ) + remoteDustSum := l.getDustSum( + lntypes.Remote, fn.None[chainfee.SatPerKWeight](), + ) // Grab the larger of the local and remote commitment fees w/o dust. commitFee := l.getCommitFee(false) @@ -2882,25 +2886,26 @@ func (l *channelLink) isOverexposedWithHtlc(htlc *lnwire.UpdateAddHTLC, // the HTLC is incoming (i.e. one that the remote sent), a boolean denoting // whether to evaluate on the local or remote commit, and finally an HTLC // amount to test. -type dustClosure func(chainfee.SatPerKWeight, bool, bool, btcutil.Amount) bool +type dustClosure func(feerate chainfee.SatPerKWeight, incoming bool, + whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool // dustHelper is used to construct the dustClosure. func dustHelper(chantype channeldb.ChannelType, localDustLimit, remoteDustLimit btcutil.Amount) dustClosure { - isDust := func(feerate chainfee.SatPerKWeight, incoming, - localCommit bool, amt btcutil.Amount) bool { + isDust := func(feerate chainfee.SatPerKWeight, incoming bool, + whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool { - if localCommit { - return lnwallet.HtlcIsDust( - chantype, incoming, true, feerate, amt, - localDustLimit, - ) + var dustLimit btcutil.Amount + if whoseCommit.IsLocal() { + dustLimit = localDustLimit + } else { + dustLimit = remoteDustLimit } return lnwallet.HtlcIsDust( - chantype, incoming, false, feerate, amt, - remoteDustLimit, + chantype, incoming, whoseCommit, feerate, amt, + dustLimit, ) } diff --git a/htlcswitch/mailbox.go b/htlcswitch/mailbox.go index a729e3ba50..9b82f8912e 100644 --- a/htlcswitch/mailbox.go +++ b/htlcswitch/mailbox.go @@ -9,6 +9,7 @@ import ( "time" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" ) @@ -660,7 +661,8 @@ func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi, // Evaluate whether this HTLC is dust on the local commitment. if m.isDust( - m.feeRate, false, true, addPkt.amount.ToSatoshis(), + m.feeRate, false, lntypes.Local, + addPkt.amount.ToSatoshis(), ) { localDustSum += addPkt.amount @@ -668,7 +670,8 @@ func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi, // Evaluate whether this HTLC is dust on the remote commitment. if m.isDust( - m.feeRate, false, false, addPkt.amount.ToSatoshis(), + m.feeRate, false, lntypes.Remote, + addPkt.amount.ToSatoshis(), ) { remoteDustSum += addPkt.amount diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 07efd28a03..96417d9c05 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -814,7 +814,7 @@ func (f *mockChannelLink) handleSwitchPacket(pkt *htlcPacket) error { return nil } -func (f *mockChannelLink) getDustSum(remote bool, +func (f *mockChannelLink) getDustSum(whoseCommit lntypes.ChannelParty, dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { return 0 diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index bfca92a3a0..793da57dbe 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2788,8 +2788,12 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink, isDust := link.getDustClosure() // Evaluate if the HTLC is dust on either sides' commitment. - isLocalDust := isDust(feeRate, incoming, true, amount.ToSatoshis()) - isRemoteDust := isDust(feeRate, incoming, false, amount.ToSatoshis()) + isLocalDust := isDust( + feeRate, incoming, lntypes.Local, amount.ToSatoshis(), + ) + isRemoteDust := isDust( + feeRate, incoming, lntypes.Remote, amount.ToSatoshis(), + ) if !(isLocalDust || isRemoteDust) { // If the HTLC is not dust on either commitment, it's fine to @@ -2807,7 +2811,7 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink, // sum for it. if isLocalDust { localSum := link.getDustSum( - false, fn.None[chainfee.SatPerKWeight](), + lntypes.Local, fn.None[chainfee.SatPerKWeight](), ) localSum += localMailDust @@ -2827,7 +2831,7 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink, // reached this point. if isRemoteDust { remoteSum := link.getDustSum( - true, fn.None[chainfee.SatPerKWeight](), + lntypes.Remote, fn.None[chainfee.SatPerKWeight](), ) remoteSum += remoteMailDust diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index ce00cd8781..0bc0df2d46 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -4319,7 +4319,7 @@ func TestSwitchDustForwarding(t *testing.T) { } checkAlmostDust := func(link *channelLink, mbox MailBox, - remote bool) bool { + whoseCommit lntypes.ChannelParty) bool { timeout := time.After(15 * time.Second) pollInterval := 300 * time.Millisecond @@ -4335,12 +4335,12 @@ func TestSwitchDustForwarding(t *testing.T) { } linkDust := link.getDustSum( - remote, fn.None[chainfee.SatPerKWeight](), + whoseCommit, fn.None[chainfee.SatPerKWeight](), ) localMailDust, remoteMailDust := mbox.DustPackets() totalDust := linkDust - if remote { + if whoseCommit.IsRemote() { totalDust += remoteMailDust } else { totalDust += localMailDust @@ -4359,7 +4359,11 @@ func TestSwitchDustForwarding(t *testing.T) { n.firstBobChannelLink.ChanID(), n.firstBobChannelLink.ShortChanID(), ) - require.True(t, checkAlmostDust(n.firstBobChannelLink, bobMbox, false)) + require.True( + t, checkAlmostDust( + n.firstBobChannelLink, bobMbox, lntypes.Local, + ), + ) // Sending one more HTLC should fail. SendHTLC won't error, but the // HTLC should be failed backwards. @@ -4408,7 +4412,9 @@ func TestSwitchDustForwarding(t *testing.T) { aliceBobFirstHop, uint64(bobAttemptID), nondustHtlc, ) require.NoError(t, err) - require.True(t, checkAlmostDust(n.firstBobChannelLink, bobMbox, false)) + require.True(t, checkAlmostDust( + n.firstBobChannelLink, bobMbox, lntypes.Local, + )) // Check that the HTLC failed. bobResultChan, err = n.bobServer.htlcSwitch.GetAttemptResult( @@ -4486,7 +4492,11 @@ func TestSwitchDustForwarding(t *testing.T) { aliceMbox := aliceOrch.GetOrCreateMailBox( n.aliceChannelLink.ChanID(), n.aliceChannelLink.ShortChanID(), ) - require.True(t, checkAlmostDust(n.aliceChannelLink, aliceMbox, true)) + require.True( + t, checkAlmostDust( + n.aliceChannelLink, aliceMbox, lntypes.Remote, + ), + ) err = n.aliceServer.htlcSwitch.SendHTLC( n.aliceChannelLink.ShortChanID(), uint64(aliceAttemptID), diff --git a/input/script_utils.go b/input/script_utils.go index 80997eed4d..104c242510 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "golang.org/x/crypto/ripemd160" ) @@ -789,10 +790,10 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // unilaterally spend the created output. func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, revokeKey *btcec.PublicKey, payHash []byte, - localCommit bool) (*HtlcScriptTree, error) { + whoseCommit lntypes.ChannelParty) (*HtlcScriptTree, error) { var hType htlcType - if localCommit { + if whoseCommit.IsLocal() { hType = htlcLocalOutgoing } else { hType = htlcRemoteIncoming @@ -1348,10 +1349,11 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // the tap leaf are returned. func ReceiverHTLCScriptTaproot(cltvExpiry uint32, senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey, - payHash []byte, ourCommit bool) (*HtlcScriptTree, error) { + payHash []byte, whoseCommit lntypes.ChannelParty, +) (*HtlcScriptTree, error) { var hType htlcType - if ourCommit { + if whoseCommit.IsLocal() { hType = htlcLocalIncoming } else { hType = htlcRemoteOutgoing diff --git a/input/size_test.go b/input/size_test.go index 9c3446afb3..daa7053ccf 100644 --- a/input/size_test.go +++ b/input/size_test.go @@ -13,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" ) @@ -1073,7 +1074,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1115,7 +1116,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1157,7 +1158,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1203,7 +1204,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1263,7 +1264,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1309,7 +1310,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1394,7 +1395,8 @@ func genTimeoutTx(t *testing.T, ) if chanType.IsTaproot() { tapscriptTree, err = input.SenderHTLCScriptTaproot( - testPubkey, testPubkey, testPubkey, testHash160, false, + testPubkey, testPubkey, testPubkey, testHash160, + lntypes.Remote, ) require.NoError(t, err) @@ -1463,7 +1465,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { if chanType.IsTaproot() { tapscriptTree, err = input.ReceiverHTLCScriptTaproot( testCLTVExpiry, testPubkey, testPubkey, testPubkey, - testHash160, false, + testHash160, lntypes.Remote, ) require.NoError(t, err) diff --git a/input/taproot_test.go b/input/taproot_test.go index 801b0fef4d..434be2dfdb 100644 --- a/input/taproot_test.go +++ b/input/taproot_test.go @@ -48,7 +48,7 @@ func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -471,7 +471,7 @@ func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := ReceiverHTLCScriptTaproot( cltvExpiry, senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) diff --git a/lntypes/channel_party.go b/lntypes/channel_party.go new file mode 100644 index 0000000000..be800541bd --- /dev/null +++ b/lntypes/channel_party.go @@ -0,0 +1,52 @@ +package lntypes + +import "fmt" + +// ChannelParty is a type used to have an unambiguous description of which node +// is being referred to. This eliminates the need to describe as "local" or +// "remote" using bool. +type ChannelParty uint8 + +const ( + // Local is a ChannelParty constructor that is used to refer to the + // node that is running. + Local ChannelParty = iota + + // Remote is a ChannelParty constructor that is used to refer to the + // node on the other end of the peer connection. + Remote +) + +// String provides a string representation of ChannelParty (useful for logging). +func (p ChannelParty) String() string { + switch p { + case Local: + return "Local" + case Remote: + return "Remote" + default: + panic(fmt.Sprintf("invalid ChannelParty value: %d", p)) + } +} + +// CounterParty inverts the role of the ChannelParty. +func (p ChannelParty) CounterParty() ChannelParty { + switch p { + case Local: + return Remote + case Remote: + return Local + default: + panic(fmt.Sprintf("invalid ChannelParty value: %v", p)) + } +} + +// IsLocal returns true if the ChannelParty is Local. +func (p ChannelParty) IsLocal() bool { + return p == Local +} + +// IsRemote returns true if the ChannelParty is Remote. +func (p ChannelParty) IsRemote() bool { + return p == Remote +} diff --git a/lnwallet/chancloser/chancloser.go b/lnwallet/chancloser/chancloser.go index 3f5e730c0c..57033d4b36 100644 --- a/lnwallet/chancloser/chancloser.go +++ b/lnwallet/chancloser/chancloser.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -207,8 +208,8 @@ type ChanCloser struct { // settled channel funds to. remoteDeliveryScript []byte - // locallyInitiated is true if we initiated the channel close. - locallyInitiated bool + // closer is ChannelParty who initiated the coop close + closer lntypes.ChannelParty // cachedClosingSigned is a cached copy of a received ClosingSigned that // we use to handle a specific race condition caused by the independent @@ -267,7 +268,8 @@ func (d *SimpleCoopFeeEstimator) EstimateFee(chanType channeldb.ChannelType, // be populated iff, we're the initiator of this closing request. func NewChanCloser(cfg ChanCloseCfg, deliveryScript []byte, idealFeePerKw chainfee.SatPerKWeight, negotiationHeight uint32, - closeReq *htlcswitch.ChanClose, locallyInitiated bool) *ChanCloser { + closeReq *htlcswitch.ChanClose, + closer lntypes.ChannelParty) *ChanCloser { chanPoint := cfg.Channel.ChannelPoint() cid := lnwire.NewChanIDFromOutPoint(chanPoint) @@ -283,7 +285,7 @@ func NewChanCloser(cfg ChanCloseCfg, deliveryScript []byte, priorFeeOffers: make( map[btcutil.Amount]*lnwire.ClosingSigned, ), - locallyInitiated: locallyInitiated, + closer: closer, } } @@ -366,7 +368,7 @@ func (c *ChanCloser) initChanShutdown() (*lnwire.Shutdown, error) { // message we are about to send in order to ensure that if a // re-establish occurs then we will re-send the same Shutdown message. shutdownInfo := channeldb.NewShutdownInfo( - c.localDeliveryScript, c.locallyInitiated, + c.localDeliveryScript, c.closer.IsLocal(), ) err := c.cfg.Channel.MarkShutdownSent(shutdownInfo) if err != nil { @@ -650,7 +652,7 @@ func (c *ChanCloser) BeginNegotiation() (fn.Option[lnwire.ClosingSigned], // externally consistent, and reflect that the channel is being // shutdown by the time the closing request returns. err := c.cfg.Channel.MarkCoopBroadcasted( - nil, c.locallyInitiated, + nil, c.closer, ) if err != nil { return noClosingSigned, err @@ -861,7 +863,7 @@ func (c *ChanCloser) ReceiveClosingSigned( //nolint:funlen // database, such that it can be republished if something goes // wrong. err = c.cfg.Channel.MarkCoopBroadcasted( - closeTx, c.locallyInitiated, + closeTx, c.closer, ) if err != nil { return noClosing, err diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go index 1956f0d2b0..9a90d0ab2b 100644 --- a/lnwallet/chancloser/chancloser_test.go +++ b/lnwallet/chancloser/chancloser_test.go @@ -16,6 +16,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -150,7 +151,9 @@ func (m *mockChannel) ChannelPoint() wire.OutPoint { return m.chanPoint } -func (m *mockChannel) MarkCoopBroadcasted(*wire.MsgTx, bool) error { +func (m *mockChannel) MarkCoopBroadcasted(*wire.MsgTx, + lntypes.ChannelParty) error { + return nil } @@ -338,7 +341,7 @@ func TestMaxFeeClamp(t *testing.T) { Channel: &channel, MaxFee: test.inputMaxFee, FeeEstimator: &SimpleCoopFeeEstimator{}, - }, nil, test.idealFee, 0, nil, false, + }, nil, test.idealFee, 0, nil, lntypes.Remote, ) // We'll call initFeeBaseline early here since we need @@ -379,7 +382,7 @@ func TestMaxFeeBailOut(t *testing.T) { MaxFee: idealFee * 2, } chanCloser := NewChanCloser( - closeCfg, nil, idealFee, 0, nil, false, + closeCfg, nil, idealFee, 0, nil, lntypes.Remote, ) // We'll now force the channel state into the @@ -503,7 +506,7 @@ func TestTaprootFastClose(t *testing.T) { DisableChannel: func(wire.OutPoint) error { return nil }, - }, nil, idealFee, 0, nil, true, + }, nil, idealFee, 0, nil, lntypes.Local, ) aliceCloser.initFeeBaseline() @@ -520,7 +523,7 @@ func TestTaprootFastClose(t *testing.T) { DisableChannel: func(wire.OutPoint) error { return nil }, - }, nil, idealFee, 0, nil, false, + }, nil, idealFee, 0, nil, lntypes.Remote, ) bobCloser.initFeeBaseline() diff --git a/lnwallet/chancloser/interface.go b/lnwallet/chancloser/interface.go index 40b81efb4d..2e9fa98ae8 100644 --- a/lnwallet/chancloser/interface.go +++ b/lnwallet/chancloser/interface.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" @@ -33,7 +34,7 @@ type Channel interface { //nolint:interfacebloat // MarkCoopBroadcasted persistently marks that the channel close // transaction has been broadcast. - MarkCoopBroadcasted(*wire.MsgTx, bool) error + MarkCoopBroadcasted(*wire.MsgTx, lntypes.ChannelParty) error // MarkShutdownSent persists the given ShutdownInfo. The existence of // the ShutdownInfo represents the fact that the Shutdown message has diff --git a/lnwallet/channel.go b/lnwallet/channel.go index abe8b62766..afe10b950a 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -271,9 +271,9 @@ type commitment struct { // update number of this commitment. height uint64 - // isOurs indicates whether this is the local or remote node's version - // of the commitment. - isOurs bool + // whoseCommit indicates whether this is the local or remote node's + // version of the commitment. + whoseCommit lntypes.ChannelParty // [our|their]MessageIndex are indexes into the HTLC log, up to which // this commitment transaction includes. These indexes allow both sides @@ -352,8 +352,9 @@ type commitment struct { // massed in is to be retained for each output within the commitment // transition. This ensures that we don't assign multiple HTLCs to the same // index within the commitment transaction. -func locateOutputIndex(p *PaymentDescriptor, tx *wire.MsgTx, ourCommit bool, - dups map[PaymentHash][]int32, cltvs []uint32) (int32, error) { +func locateOutputIndex(p *PaymentDescriptor, tx *wire.MsgTx, + whoseCommit lntypes.ChannelParty, dups map[PaymentHash][]int32, + cltvs []uint32) (int32, error) { // Checks to see if element (e) exists in slice (s). contains := func(s []int32, e int32) bool { @@ -370,7 +371,7 @@ func locateOutputIndex(p *PaymentDescriptor, tx *wire.MsgTx, ourCommit bool, // required as the commitment states are asymmetric in order to ascribe // blame in the case of a contract breach. pkScript := p.theirPkScript - if ourCommit { + if whoseCommit.IsLocal() { pkScript = p.ourPkScript } @@ -418,7 +419,7 @@ func (c *commitment) populateHtlcIndexes(chanType channeldb.ChannelType, // indexes within the commitment view for a particular HTLC. populateIndex := func(htlc *PaymentDescriptor, incoming bool) error { isDust := HtlcIsDust( - chanType, incoming, c.isOurs, c.feePerKw, + chanType, incoming, c.whoseCommit, c.feePerKw, htlc.Amount.ToSatoshis(), c.dustLimit, ) @@ -427,21 +428,21 @@ func (c *commitment) populateHtlcIndexes(chanType channeldb.ChannelType, // If this is our commitment transaction, and this is a dust // output then we mark it as such using a -1 index. - case c.isOurs && isDust: + case c.whoseCommit.IsLocal() && isDust: htlc.localOutputIndex = -1 // If this is the commitment transaction of the remote party, // and this is a dust output then we mark it as such using a -1 // index. - case !c.isOurs && isDust: + case c.whoseCommit.IsRemote() && isDust: htlc.remoteOutputIndex = -1 // If this is our commitment transaction, then we'll need to // locate the output and the index so we can verify an HTLC // signatures. - case c.isOurs: + case c.whoseCommit.IsLocal(): htlc.localOutputIndex, err = locateOutputIndex( - htlc, c.txn, c.isOurs, dups, cltvs, + htlc, c.txn, c.whoseCommit, dups, cltvs, ) if err != nil { return err @@ -460,9 +461,9 @@ func (c *commitment) populateHtlcIndexes(chanType channeldb.ChannelType, // Otherwise, this is there remote party's commitment // transaction and we only need to populate the remote output // index within the HTLC index. - case !c.isOurs: + case c.whoseCommit.IsRemote(): htlc.remoteOutputIndex, err = locateOutputIndex( - htlc, c.txn, c.isOurs, dups, cltvs, + htlc, c.txn, c.whoseCommit, dups, cltvs, ) if err != nil { return err @@ -497,7 +498,9 @@ func (c *commitment) populateHtlcIndexes(chanType channeldb.ChannelType, // toDiskCommit converts the target commitment into a format suitable to be // written to disk after an accepted state transition. -func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { +func (c *commitment) toDiskCommit( + whoseCommit lntypes.ChannelParty) *channeldb.ChannelCommitment { + numHtlcs := len(c.outgoingHTLCs) + len(c.incomingHTLCs) commit := &channeldb.ChannelCommitment{ @@ -517,7 +520,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { for _, htlc := range c.outgoingHTLCs { outputIndex := htlc.localOutputIndex - if !ourCommit { + if whoseCommit.IsRemote() { outputIndex = htlc.remoteOutputIndex } @@ -533,7 +536,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { } copy(h.OnionBlob[:], htlc.OnionBlob) - if ourCommit && htlc.sig != nil { + if whoseCommit.IsLocal() && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -542,7 +545,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { for _, htlc := range c.incomingHTLCs { outputIndex := htlc.localOutputIndex - if !ourCommit { + if whoseCommit.IsRemote() { outputIndex = htlc.remoteOutputIndex } @@ -557,7 +560,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { BlindingPoint: htlc.BlindingPoint, } copy(h.OnionBlob[:], htlc.OnionBlob) - if ourCommit && htlc.sig != nil { + if whoseCommit.IsLocal() && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -574,8 +577,8 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { // restart a channel session. func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, commitHeight uint64, htlc *channeldb.HTLC, localCommitKeys, - remoteCommitKeys *CommitmentKeyRing, isLocal bool) (PaymentDescriptor, - error) { + remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, +) (PaymentDescriptor, error) { // The proper pkScripts for this PaymentDescriptor must be // generated so we can easily locate them within the commitment @@ -593,13 +596,13 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // transaction. As we'll mark dust with a special output index in the // on-disk state snapshot. isDustLocal := HtlcIsDust( - chanType, htlc.Incoming, true, feeRate, + chanType, htlc.Incoming, lntypes.Local, feeRate, htlc.Amt.ToSatoshis(), lc.channelState.LocalChanCfg.DustLimit, ) if !isDustLocal && localCommitKeys != nil { scriptInfo, err := genHtlcScript( - chanType, htlc.Incoming, true, htlc.RefundTimeout, - htlc.RHash, localCommitKeys, + chanType, htlc.Incoming, lntypes.Local, + htlc.RefundTimeout, htlc.RHash, localCommitKeys, ) if err != nil { return pd, err @@ -608,13 +611,13 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, ourWitnessScript = scriptInfo.WitnessScriptToSign() } isDustRemote := HtlcIsDust( - chanType, htlc.Incoming, false, feeRate, + chanType, htlc.Incoming, lntypes.Remote, feeRate, htlc.Amt.ToSatoshis(), lc.channelState.RemoteChanCfg.DustLimit, ) if !isDustRemote && remoteCommitKeys != nil { scriptInfo, err := genHtlcScript( - chanType, htlc.Incoming, false, htlc.RefundTimeout, - htlc.RHash, remoteCommitKeys, + chanType, htlc.Incoming, lntypes.Remote, + htlc.RefundTimeout, htlc.RHash, remoteCommitKeys, ) if err != nil { return pd, err @@ -630,7 +633,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, localOutputIndex int32 remoteOutputIndex int32 ) - if isLocal { + if whoseCommit.IsLocal() { localOutputIndex = htlc.OutputIndex } else { remoteOutputIndex = htlc.OutputIndex @@ -663,8 +666,8 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // for each side. func (lc *LightningChannel) extractPayDescs(commitHeight uint64, feeRate chainfee.SatPerKWeight, htlcs []channeldb.HTLC, localCommitKeys, - remoteCommitKeys *CommitmentKeyRing, isLocal bool) ([]PaymentDescriptor, - []PaymentDescriptor, error) { + remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, +) ([]PaymentDescriptor, []PaymentDescriptor, error) { var ( incomingHtlcs []PaymentDescriptor @@ -684,7 +687,7 @@ func (lc *LightningChannel) extractPayDescs(commitHeight uint64, payDesc, err := lc.diskHtlcToPayDesc( feeRate, commitHeight, &htlc, localCommitKeys, remoteCommitKeys, - isLocal, + whoseCommit, ) if err != nil { return incomingHtlcs, outgoingHtlcs, err @@ -703,7 +706,8 @@ func (lc *LightningChannel) extractPayDescs(commitHeight uint64, // diskCommitToMemCommit converts the on-disk commitment format to our // in-memory commitment format which is needed in order to properly resume // channel operations after a restart. -func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, +func (lc *LightningChannel) diskCommitToMemCommit( + whoseCommit lntypes.ChannelParty, diskCommit *channeldb.ChannelCommitment, localCommitPoint, remoteCommitPoint *btcec.PublicKey) (*commitment, error) { @@ -715,14 +719,16 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, var localCommitKeys, remoteCommitKeys *CommitmentKeyRing if localCommitPoint != nil { localCommitKeys = DeriveCommitmentKeys( - localCommitPoint, true, lc.channelState.ChanType, + localCommitPoint, lntypes.Local, + lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) } if remoteCommitPoint != nil { remoteCommitKeys = DeriveCommitmentKeys( - remoteCommitPoint, false, lc.channelState.ChanType, + remoteCommitPoint, lntypes.Remote, + lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) @@ -735,7 +741,7 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, diskCommit.CommitHeight, chainfee.SatPerKWeight(diskCommit.FeePerKw), diskCommit.Htlcs, localCommitKeys, remoteCommitKeys, - isLocal, + whoseCommit, ) if err != nil { return nil, err @@ -745,7 +751,7 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, // commitment state as it was originally present in memory. commit := &commitment{ height: diskCommit.CommitHeight, - isOurs: isLocal, + whoseCommit: whoseCommit, ourBalance: diskCommit.LocalBalance, theirBalance: diskCommit.RemoteBalance, ourMessageIndex: diskCommit.LocalLogIndex, @@ -759,7 +765,7 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, incomingHTLCs: incomingHtlcs, outgoingHTLCs: outgoingHtlcs, } - if isLocal { + if whoseCommit.IsLocal() { commit.dustLimit = lc.channelState.LocalChanCfg.DustLimit } else { commit.dustLimit = lc.channelState.RemoteChanCfg.DustLimit @@ -1102,12 +1108,12 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) isDustRemote := HtlcIsDust( - lc.channelState.ChanType, false, false, feeRate, - wireMsg.Amount.ToSatoshis(), remoteDustLimit, + lc.channelState.ChanType, false, lntypes.Remote, + feeRate, wireMsg.Amount.ToSatoshis(), remoteDustLimit, ) if !isDustRemote { scriptInfo, err := genHtlcScript( - lc.channelState.ChanType, false, false, + lc.channelState.ChanType, false, lntypes.Remote, wireMsg.Expiry, wireMsg.PaymentHash, remoteCommitKeys, ) @@ -1400,7 +1406,7 @@ func (lc *LightningChannel) restoreCommitState( // commitment into our in-memory commitment format, inserting it into // the local commitment chain. localCommit, err := lc.diskCommitToMemCommit( - true, localCommitState, localCommitPoint, + lntypes.Local, localCommitState, localCommitPoint, remoteCommitPoint, ) if err != nil { @@ -1413,7 +1419,7 @@ func (lc *LightningChannel) restoreCommitState( // We'll also do the same for the remote commitment chain. remoteCommit, err := lc.diskCommitToMemCommit( - false, remoteCommitState, localCommitPoint, + lntypes.Remote, remoteCommitState, localCommitPoint, remoteCommitPoint, ) if err != nil { @@ -1445,7 +1451,7 @@ func (lc *LightningChannel) restoreCommitState( // corresponding state for the local commitment chain. pendingCommitPoint := lc.channelState.RemoteNextRevocation pendingRemoteCommit, err = lc.diskCommitToMemCommit( - false, &pendingRemoteCommitDiff.Commitment, + lntypes.Remote, &pendingRemoteCommitDiff.Commitment, nil, pendingCommitPoint, ) if err != nil { @@ -1459,8 +1465,10 @@ func (lc *LightningChannel) restoreCommitState( // We'll also re-create the set of commitment keys needed to // fully re-derive the state. pendingRemoteKeyChain = DeriveCommitmentKeys( - pendingCommitPoint, false, lc.channelState.ChanType, - &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, + pendingCommitPoint, lntypes.Remote, + lc.channelState.ChanType, + &lc.channelState.LocalChanCfg, + &lc.channelState.RemoteChanCfg, ) } @@ -1971,7 +1979,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, // With the commitment point generated, we can now generate the four // keys we'll need to reconstruct the commitment state, keyRing := DeriveCommitmentKeys( - commitmentPoint, false, chanState.ChanType, + commitmentPoint, lntypes.Remote, chanState.ChanType, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) @@ -2174,7 +2182,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. scriptInfo, err := genHtlcScript( - chanState.ChanType, htlc.Incoming, false, + chanState.ChanType, htlc.Incoming, lntypes.Remote, htlc.RefundTimeout, htlc.RHash, keyRing, ) if err != nil { @@ -2377,7 +2385,7 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, // If the HTLC is dust, then we'll skip it as it doesn't have // an output on the commitment transaction. if HtlcIsDust( - chanState.ChanType, htlc.Incoming, false, + chanState.ChanType, htlc.Incoming, lntypes.Remote, chainfee.SatPerKWeight(revokedLog.FeePerKw), htlc.Amt.ToSatoshis(), chanState.RemoteChanCfg.DustLimit, @@ -2424,8 +2432,9 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, // covenants. Depending on the two bits, we'll either be using a timeout or // success transaction which have different weights. func HtlcIsDust(chanType channeldb.ChannelType, - incoming, ourCommit bool, feePerKw chainfee.SatPerKWeight, - htlcAmt, dustLimit btcutil.Amount) bool { + incoming bool, whoseCommit lntypes.ChannelParty, + feePerKw chainfee.SatPerKWeight, htlcAmt, dustLimit btcutil.Amount, +) bool { // First we'll determine the fee required for this HTLC based on if this is // an incoming HTLC or not, and also on whose commitment transaction it @@ -2435,25 +2444,25 @@ func HtlcIsDust(chanType channeldb.ChannelType, // If this is an incoming HTLC on our commitment transaction, then the // second-level transaction will be a success transaction. - case incoming && ourCommit: + case incoming && whoseCommit.IsLocal(): htlcFee = HtlcSuccessFee(chanType, feePerKw) // If this is an incoming HTLC on their commitment transaction, then // we'll be using a second-level timeout transaction as they've added // this HTLC. - case incoming && !ourCommit: + case incoming && whoseCommit.IsRemote(): htlcFee = HtlcTimeoutFee(chanType, feePerKw) // If this is an outgoing HTLC on our commitment transaction, then // we'll be using a timeout transaction as we're the sender of the // HTLC. - case !incoming && ourCommit: + case !incoming && whoseCommit.IsLocal(): htlcFee = HtlcTimeoutFee(chanType, feePerKw) // If this is an outgoing HTLC on their commitment transaction, then // we'll be using an HTLC success transaction as they're the receiver // of this HTLC. - case !incoming && !ourCommit: + case !incoming && whoseCommit.IsRemote(): htlcFee = HtlcSuccessFee(chanType, feePerKw) } @@ -2508,13 +2517,14 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *ht // both local and remote commitment transactions in order to sign or verify new // commitment updates. A fully populated commitment is returned which reflects // the proper balances for both sides at this point in the commitment chain. -func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, +func (lc *LightningChannel) fetchCommitmentView( + whoseCommitChain lntypes.ChannelParty, ourLogIndex, ourHtlcIndex, theirLogIndex, theirHtlcIndex uint64, keyRing *CommitmentKeyRing) (*commitment, error) { commitChain := lc.localCommitChain dustLimit := lc.channelState.LocalChanCfg.DustLimit - if remoteChain { + if whoseCommitChain.IsRemote() { commitChain = lc.remoteCommitChain dustLimit = lc.channelState.RemoteChanCfg.DustLimit } @@ -2528,7 +2538,8 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, // initiator. htlcView := lc.fetchHTLCView(theirLogIndex, ourLogIndex) ourBalance, theirBalance, _, filteredHTLCView, err := lc.computeView( - htlcView, remoteChain, true, fn.None[chainfee.SatPerKWeight](), + htlcView, whoseCommitChain, true, + fn.None[chainfee.SatPerKWeight](), ) if err != nil { return nil, err @@ -2537,8 +2548,8 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, // Actually generate unsigned commitment transaction for this view. commitTx, err := lc.commitBuilder.createUnsignedCommitmentTx( - ourBalance, theirBalance, !remoteChain, feePerKw, nextHeight, - filteredHTLCView, keyRing, + ourBalance, theirBalance, whoseCommitChain, feePerKw, + nextHeight, filteredHTLCView, keyRing, ) if err != nil { return nil, err @@ -2587,7 +2598,7 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, height: nextHeight, feePerKw: feePerKw, dustLimit: dustLimit, - isOurs: !remoteChain, + whoseCommit: whoseCommitChain, } // In order to ensure _none_ of the HTLC's associated with this new @@ -2635,7 +2646,8 @@ func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn { // method. func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - remoteChain, mutateState bool) (*htlcView, error) { + whoseCommitChain lntypes.ChannelParty, mutateState bool, +) (*htlcView, error) { // We initialize the view's fee rate to the fee rate of the unfiltered // view. If any fee updates are found when evaluating the view, it will @@ -2663,8 +2675,8 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // Process fee updates, updating the current feePerKw. case FeeUpdate: processFeeUpdate( - entry, nextHeight, remoteChain, mutateState, - newView, + entry, nextHeight, whoseCommitChain, + mutateState, newView, ) continue } @@ -2672,19 +2684,22 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // If we're settling an inbound HTLC, and it hasn't been // processed yet, then increment our state tracking the total // number of satoshis we've received within the channel. - if mutateState && entry.EntryType == Settle && !remoteChain && + if mutateState && entry.EntryType == Settle && + whoseCommitChain.IsLocal() && entry.removeCommitHeightLocal == 0 { lc.channelState.TotalMSatReceived += entry.Amount } - addEntry, err := lc.fetchParent(entry, remoteChain, true) + addEntry, err := lc.fetchParent( + entry, whoseCommitChain, lntypes.Remote, + ) if err != nil { return nil, err } skipThem[addEntry.HtlcIndex] = struct{}{} processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, remoteChain, true, mutateState) + nextHeight, whoseCommitChain, true, mutateState) } for _, entry := range view.theirUpdates { switch entry.EntryType { @@ -2695,8 +2710,8 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // Process fee updates, updating the current feePerKw. case FeeUpdate: processFeeUpdate( - entry, nextHeight, remoteChain, mutateState, - newView, + entry, nextHeight, whoseCommitChain, + mutateState, newView, ) continue } @@ -2705,19 +2720,23 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // and it hasn't been processed, yet, the increment our state // tracking the total number of satoshis we've sent within the // channel. - if mutateState && entry.EntryType == Settle && !remoteChain && + if mutateState && entry.EntryType == Settle && + whoseCommitChain.IsLocal() && entry.removeCommitHeightLocal == 0 { + lc.channelState.TotalMSatSent += entry.Amount } - addEntry, err := lc.fetchParent(entry, remoteChain, false) + addEntry, err := lc.fetchParent( + entry, whoseCommitChain, lntypes.Local, + ) if err != nil { return nil, err } skipUs[addEntry.HtlcIndex] = struct{}{} processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, remoteChain, false, mutateState) + nextHeight, whoseCommitChain, false, mutateState) } // Next we take a second pass through all the log entries, skipping any @@ -2730,7 +2749,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } processAddEntry(entry, ourBalance, theirBalance, nextHeight, - remoteChain, false, mutateState) + whoseCommitChain, false, mutateState) newView.ourUpdates = append(newView.ourUpdates, entry) } for _, entry := range view.theirUpdates { @@ -2740,7 +2759,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } processAddEntry(entry, ourBalance, theirBalance, nextHeight, - remoteChain, true, mutateState) + whoseCommitChain, true, mutateState) newView.theirUpdates = append(newView.theirUpdates, entry) } @@ -2750,14 +2769,15 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // fetchParent is a helper that looks up update log parent entries in the // appropriate log. func (lc *LightningChannel) fetchParent(entry *PaymentDescriptor, - remoteChain, remoteLog bool) (*PaymentDescriptor, error) { + whoseCommitChain, whoseUpdateLog lntypes.ChannelParty, +) (*PaymentDescriptor, error) { var ( updateLog *updateLog logName string ) - if remoteLog { + if whoseUpdateLog.IsRemote() { updateLog = lc.remoteUpdateLog logName = "remote" } else { @@ -2781,11 +2801,16 @@ func (lc *LightningChannel) fetchParent(entry *PaymentDescriptor, // The parent add height should never be zero at this point. If // that's the case we probably forgot to send a new commitment. - case remoteChain && addEntry.addCommitHeightRemote == 0: + case whoseCommitChain.IsRemote() && + addEntry.addCommitHeightRemote == 0: + return nil, fmt.Errorf("parent entry %d for update %d "+ "had zero remote add height", entry.ParentIndex, entry.LogIndex) - case !remoteChain && addEntry.addCommitHeightLocal == 0: + + case whoseCommitChain.IsLocal() && + addEntry.addCommitHeightLocal == 0: + return nil, fmt.Errorf("parent entry %d for update %d "+ "had zero local add height", entry.ParentIndex, entry.LogIndex) @@ -2798,15 +2823,16 @@ func (lc *LightningChannel) fetchParent(entry *PaymentDescriptor, // If the HTLC hasn't yet been committed in either chain, then the height it // was committed is updated. Keeping track of this inclusion height allows us to // later compact the log once the change is fully committed in both chains. -func processAddEntry(htlc *PaymentDescriptor, ourBalance, theirBalance *lnwire.MilliSatoshi, - nextHeight uint64, remoteChain bool, isIncoming, mutateState bool) { +func processAddEntry(htlc *PaymentDescriptor, ourBalance, + theirBalance *lnwire.MilliSatoshi, nextHeight uint64, + whoseCommitChain lntypes.ChannelParty, isIncoming, mutateState bool) { // If we're evaluating this entry for the remote chain (to create/view // a new commitment), then we'll may be updating the height this entry // was added to the chain. Otherwise, we may be updating the entry's // height w.r.t the local chain. var addHeight *uint64 - if remoteChain { + if whoseCommitChain.IsRemote() { addHeight = &htlc.addCommitHeightRemote } else { addHeight = &htlc.addCommitHeightLocal @@ -2837,10 +2863,10 @@ func processAddEntry(htlc *PaymentDescriptor, ourBalance, theirBalance *lnwire.M // is skipped. func processRemoveEntry(htlc *PaymentDescriptor, ourBalance, theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - remoteChain bool, isIncoming, mutateState bool) { + whoseCommitChain lntypes.ChannelParty, isIncoming, mutateState bool) { var removeHeight *uint64 - if remoteChain { + if whoseCommitChain.IsRemote() { removeHeight = &htlc.removeCommitHeightRemote } else { removeHeight = &htlc.removeCommitHeightLocal @@ -2885,14 +2911,15 @@ func processRemoveEntry(htlc *PaymentDescriptor, ourBalance, // processFeeUpdate processes a log update that updates the current commitment // fee. func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, - remoteChain bool, mutateState bool, view *htlcView) { + whoseCommitChain lntypes.ChannelParty, mutateState bool, view *htlcView, +) { // Fee updates are applied for all commitments after they are // sent/received, so we consider them being added and removed at the // same height. var addHeight *uint64 var removeHeight *uint64 - if remoteChain { + if whoseCommitChain.IsRemote() { addHeight = &feeUpdate.addCommitHeightRemote removeHeight = &feeUpdate.removeCommitHeightRemote } else { @@ -2945,7 +2972,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // sigJob will be generated and appended to the current batch. for _, htlc := range remoteCommitView.incomingHTLCs { if HtlcIsDust( - chanType, true, false, feePerKw, + chanType, true, lntypes.Remote, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -3014,7 +3041,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, } for _, htlc := range remoteCommitView.outgoingHTLCs { if HtlcIsDust( - chanType, false, false, feePerKw, + chanType, false, lntypes.Remote, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -3212,7 +3239,7 @@ func (lc *LightningChannel) createCommitDiff( // With the set of log updates mapped into wire messages, we'll now // convert the in-memory commit into a format suitable for writing to // disk. - diskCommit := newCommit.toDiskCommit(false) + diskCommit := newCommit.toDiskCommit(lntypes.Remote) return &channeldb.CommitDiff{ Commitment: *diskCommit, @@ -3463,12 +3490,13 @@ func (lc *LightningChannel) applyCommitFee( // PaymentDescriptor if we are validating in the state when adding a new HTLC, // or nil otherwise. func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, - ourLogCounter uint64, remoteChain bool, buffer BufferType, - predictOurAdd, predictTheirAdd *PaymentDescriptor) error { + ourLogCounter uint64, whoseCommitChain lntypes.ChannelParty, + buffer BufferType, predictOurAdd, predictTheirAdd *PaymentDescriptor, +) error { // First fetch the initial balance before applying any updates. commitChain := lc.localCommitChain - if remoteChain { + if whoseCommitChain.IsRemote() { commitChain = lc.remoteCommitChain } ourInitialBalance := commitChain.tip().ourBalance @@ -3488,7 +3516,8 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, } ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView( - view, remoteChain, false, fn.None[chainfee.SatPerKWeight](), + view, whoseCommitChain, false, + fn.None[chainfee.SatPerKWeight](), ) if err != nil { return err @@ -3703,7 +3732,7 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { // dare to fail hard here. We assume peers can deal with the empty sig // and continue channel operation. We log an error so that the bug // causing this can be tracked down. - if !lc.oweCommitment(true) { + if !lc.oweCommitment(lntypes.Local) { lc.log.Errorf("sending empty commit sig") } @@ -3737,8 +3766,8 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { // point all updates will have to get locked-in so we enforce the // minimum requirement. err := lc.validateCommitmentSanity( - remoteACKedIndex, lc.localUpdateLog.logIndex, true, NoBuffer, - nil, nil, + remoteACKedIndex, lc.localUpdateLog.logIndex, lntypes.Remote, + NoBuffer, nil, nil, ) if err != nil { return nil, err @@ -3748,7 +3777,7 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { // used within fetchCommitmentView to derive all the keys necessary to // construct the commitment state. keyRing := DeriveCommitmentKeys( - commitPoint, false, lc.channelState.ChanType, + commitPoint, lntypes.Remote, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) @@ -3760,8 +3789,9 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { // _all_ of our changes (pending or committed) but only the remote // node's changes up to the last change we've ACK'd. newCommitView, err := lc.fetchCommitmentView( - true, lc.localUpdateLog.logIndex, lc.localUpdateLog.htlcCounter, - remoteACKedIndex, remoteHtlcIndex, keyRing, + lntypes.Remote, lc.localUpdateLog.logIndex, + lc.localUpdateLog.htlcCounter, remoteACKedIndex, + remoteHtlcIndex, keyRing, ) if err != nil { return nil, err @@ -4255,14 +4285,14 @@ func (lc *LightningChannel) ProcessChanSyncMsg( // // If the updateState boolean is set true, the add and remove heights of the // HTLCs will be set to the next commitment height. -func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, - updateState bool, dryRunFee fn.Option[chainfee.SatPerKWeight]) ( - lnwire.MilliSatoshi, lnwire.MilliSatoshi, lntypes.WeightUnit, - *htlcView, error) { +func (lc *LightningChannel) computeView(view *htlcView, + whoseCommitChain lntypes.ChannelParty, updateState bool, + dryRunFee fn.Option[chainfee.SatPerKWeight]) (lnwire.MilliSatoshi, + lnwire.MilliSatoshi, lntypes.WeightUnit, *htlcView, error) { commitChain := lc.localCommitChain dustLimit := lc.channelState.LocalChanCfg.DustLimit - if remoteChain { + if whoseCommitChain.IsRemote() { commitChain = lc.remoteCommitChain dustLimit = lc.channelState.RemoteChanCfg.DustLimit } @@ -4298,7 +4328,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, // updates are found in the logs, the commitment fee rate should be // changed, so we'll also set the feePerKw to this new value. filteredHTLCView, err := lc.evaluateHTLCView(view, &ourBalance, - &theirBalance, nextHeight, remoteChain, updateState) + &theirBalance, nextHeight, whoseCommitChain, updateState) if err != nil { return 0, 0, 0, nil, err } @@ -4328,7 +4358,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, var totalHtlcWeight lntypes.WeightUnit for _, htlc := range filteredHTLCView.ourUpdates { if HtlcIsDust( - lc.channelState.ChanType, false, !remoteChain, + lc.channelState.ChanType, false, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -4339,7 +4369,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, } for _, htlc := range filteredHTLCView.theirUpdates { if HtlcIsDust( - lc.channelState.ChanType, true, !remoteChain, + lc.channelState.ChanType, true, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -4681,7 +4711,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { // reliable, because it could be that we've sent out a new sig, but the // remote hasn't received it yet. We could then falsely assume that they // should add our updates to their remote commitment tx. - if !lc.oweCommitment(false) { + if !lc.oweCommitment(lntypes.Remote) { lc.log.Warnf("empty commit sig message received") } @@ -4698,8 +4728,8 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { // the UpdateAddHTLC msg from our peer prior to receiving the // commit-sig). err := lc.validateCommitmentSanity( - lc.remoteUpdateLog.logIndex, localACKedIndex, false, NoBuffer, - nil, nil, + lc.remoteUpdateLog.logIndex, localACKedIndex, lntypes.Local, + NoBuffer, nil, nil, ) if err != nil { return err @@ -4716,7 +4746,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { } commitPoint := input.ComputeCommitmentPoint(commitSecret[:]) keyRing := DeriveCommitmentKeys( - commitPoint, true, lc.channelState.ChanType, + commitPoint, lntypes.Local, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) @@ -4725,7 +4755,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { // we know of in the remote node's HTLC log, but only our local changes // up to the last change the remote node has ACK'd. localCommitmentView, err := lc.fetchCommitmentView( - false, localACKedIndex, localHtlcIndex, + lntypes.Local, localACKedIndex, localHtlcIndex, lc.remoteUpdateLog.logIndex, lc.remoteUpdateLog.htlcCounter, keyRing, ) @@ -4962,11 +4992,11 @@ func (lc *LightningChannel) IsChannelClean() bool { // Now check that both local and remote commitments are signing the // same updates. - if lc.oweCommitment(true) { + if lc.oweCommitment(lntypes.Local) { return false } - if lc.oweCommitment(false) { + if lc.oweCommitment(lntypes.Remote) { return false } @@ -4983,7 +5013,7 @@ func (lc *LightningChannel) OweCommitment() bool { lc.RLock() defer lc.RUnlock() - return lc.oweCommitment(true) + return lc.oweCommitment(lntypes.Local) } // NeedCommitment returns a boolean value reflecting whether we are waiting on @@ -4994,12 +5024,12 @@ func (lc *LightningChannel) NeedCommitment() bool { lc.RLock() defer lc.RUnlock() - return lc.oweCommitment(false) + return lc.oweCommitment(lntypes.Remote) } // oweCommitment is the internal version of OweCommitment. This function expects // to be executed with a lock held. -func (lc *LightningChannel) oweCommitment(local bool) bool { +func (lc *LightningChannel) oweCommitment(issuer lntypes.ChannelParty) bool { var ( remoteUpdatesPending, localUpdatesPending bool @@ -5009,7 +5039,7 @@ func (lc *LightningChannel) oweCommitment(local bool) bool { perspective string ) - if local { + if issuer.IsLocal() { perspective = "local" // There are local updates pending if our local update log is @@ -5091,7 +5121,7 @@ func (lc *LightningChannel) RevokeCurrentCommitment() (*lnwire.RevokeAndAck, // Additionally, generate a channel delta for this state transition for // persistent storage. chainTail := lc.localCommitChain.tail() - newCommitment := chainTail.toDiskCommit(true) + newCommitment := chainTail.toDiskCommit(lntypes.Local) // Get the unsigned acked remotes updates that are currently in memory. // We need them after a restart to sync our remote commitment with what @@ -5501,7 +5531,7 @@ func (lc *LightningChannel) addHTLC(htlc *lnwire.UpdateAddHTLC, // commitment tx. // // NOTE: This over-estimates the dust exposure. -func (lc *LightningChannel) GetDustSum(remote bool, +func (lc *LightningChannel) GetDustSum(whoseCommit lntypes.ChannelParty, dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { lc.RLock() @@ -5511,7 +5541,7 @@ func (lc *LightningChannel) GetDustSum(remote bool, dustLimit := lc.channelState.LocalChanCfg.DustLimit commit := lc.channelState.LocalCommitment - if remote { + if whoseCommit.IsRemote() { // Calculate dust sum on the remote's commitment. dustLimit = lc.channelState.RemoteChanCfg.DustLimit commit = lc.channelState.RemoteCommitment @@ -5535,7 +5565,7 @@ func (lc *LightningChannel) GetDustSum(remote bool, // If the satoshi amount is under the dust limit, add the msat // amount to the dust sum. if HtlcIsDust( - chanType, false, !remote, feeRate, amt, dustLimit, + chanType, false, whoseCommit, feeRate, amt, dustLimit, ) { dustSum += pd.Amount @@ -5554,7 +5584,8 @@ func (lc *LightningChannel) GetDustSum(remote bool, // If the satoshi amount is under the dust limit, add the msat // amount to the dust sum. if HtlcIsDust( - chanType, true, !remote, feeRate, amt, dustLimit, + chanType, true, whoseCommit, feeRate, + amt, dustLimit, ) { dustSum += pd.Amount @@ -5641,7 +5672,7 @@ func (lc *LightningChannel) validateAddHtlc(pd *PaymentDescriptor, // First we'll check whether this HTLC can be added to the remote // commitment transaction without violation any of the constraints. err := lc.validateCommitmentSanity( - remoteACKedIndex, lc.localUpdateLog.logIndex, true, + remoteACKedIndex, lc.localUpdateLog.logIndex, lntypes.Remote, buffer, pd, nil, ) if err != nil { @@ -5655,7 +5686,7 @@ func (lc *LightningChannel) validateAddHtlc(pd *PaymentDescriptor, // possible for us to add the HTLC. err = lc.validateCommitmentSanity( lc.remoteUpdateLog.logIndex, lc.localUpdateLog.logIndex, - false, buffer, pd, nil, + lntypes.Local, buffer, pd, nil, ) if err != nil { return err @@ -5696,8 +5727,8 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err // we use it here. The current lightning protocol does not allow to // reject ADDs already sent by the peer. err := lc.validateCommitmentSanity( - lc.remoteUpdateLog.logIndex, localACKedIndex, false, NoBuffer, - nil, pd, + lc.remoteUpdateLog.logIndex, localACKedIndex, lntypes.Local, + NoBuffer, nil, pd, ) if err != nil { return 0, err @@ -6195,9 +6226,9 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si // First, we'll generate the commitment point and the revocation point // so we can re-construct the HTLC state and also our payment key. - isOurCommit := false + commitType := lntypes.Remote keyRing := DeriveCommitmentKeys( - commitPoint, isOurCommit, chanState.ChanType, + commitPoint, commitType, chanState.ChanType, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) @@ -6209,7 +6240,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si } isRemoteInitiator := !chanState.IsInitiator htlcResolutions, err := extractHtlcResolutions( - chainfee.SatPerKWeight(remoteCommit.FeePerKw), isOurCommit, + chainfee.SatPerKWeight(remoteCommit.FeePerKw), commitType, signer, remoteCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitSpend.SpendingTx, chanState.ChanType, isRemoteInitiator, leaseExpiry, @@ -6328,7 +6359,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si } anchorResolution, err := NewAnchorResolution( - chanState, commitTxBroadcast, keyRing, false, + chanState, commitTxBroadcast, keyRing, lntypes.Remote, ) if err != nil { return nil, err @@ -6465,7 +6496,7 @@ func newOutgoingHtlcResolution(signer input.Signer, localChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, - localCommit, isCommitFromInitiator bool, + whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool, chanType channeldb.ChannelType) (*OutgoingHtlcResolution, error) { op := wire.OutPoint{ @@ -6476,7 +6507,7 @@ func newOutgoingHtlcResolution(signer input.Signer, // First, we'll re-generate the script used to send the HTLC to the // remote party within their commitment transaction. htlcScriptInfo, err := genHtlcScript( - chanType, false, localCommit, htlc.RefundTimeout, htlc.RHash, + chanType, false, whoseCommit, htlc.RefundTimeout, htlc.RHash, keyRing, ) if err != nil { @@ -6497,7 +6528,7 @@ func newOutgoingHtlcResolution(signer input.Signer, // If we're spending this HTLC output from the remote node's // commitment, then we won't need to go to the second level as our // outputs don't have a CSV delay. - if !localCommit { + if whoseCommit.IsRemote() { // With the script generated, we can completely populated the // SignDescriptor needed to sweep the output. prevFetcher := txscript.NewCannedPrevOutputFetcher( @@ -6717,7 +6748,8 @@ func newIncomingHtlcResolution(signer input.Signer, localChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, - localCommit, isCommitFromInitiator bool, chanType channeldb.ChannelType) ( + whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool, + chanType channeldb.ChannelType) ( *IncomingHtlcResolution, error) { op := wire.OutPoint{ @@ -6728,7 +6760,7 @@ func newIncomingHtlcResolution(signer input.Signer, // First, we'll re-generate the script the remote party used to // send the HTLC to us in their commitment transaction. scriptInfo, err := genHtlcScript( - chanType, true, localCommit, htlc.RefundTimeout, htlc.RHash, + chanType, true, whoseCommit, htlc.RefundTimeout, htlc.RHash, keyRing, ) if err != nil { @@ -6749,7 +6781,7 @@ func newIncomingHtlcResolution(signer input.Signer, // If we're spending this output from the remote node's commitment, // then we can skip the second layer and spend the output directly. - if !localCommit { + if whoseCommit.IsRemote() { // With the script generated, we can completely populated the // SignDescriptor needed to sweep the output. prevFetcher := txscript.NewCannedPrevOutputFetcher( @@ -6976,8 +7008,9 @@ func (r *OutgoingHtlcResolution) HtlcPoint() wire.OutPoint { // extractHtlcResolutions creates a series of outgoing HTLC resolutions, and // the local key used when generating the HTLC scrips. This function is to be // used in two cases: force close, or a unilateral close. -func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, - signer input.Signer, htlcs []channeldb.HTLC, keyRing *CommitmentKeyRing, +func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, + whoseCommit lntypes.ChannelParty, signer input.Signer, + htlcs []channeldb.HTLC, keyRing *CommitmentKeyRing, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, chanType channeldb.ChannelType, isCommitFromInitiator bool, leaseExpiry uint32) (*HtlcResolutions, error) { @@ -6985,7 +7018,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, // TODO(roasbeef): don't need to swap csv delay? dustLimit := remoteChanCfg.DustLimit csvDelay := remoteChanCfg.CsvDelay - if ourCommit { + if whoseCommit.IsLocal() { dustLimit = localChanCfg.DustLimit csvDelay = localChanCfg.CsvDelay } @@ -6999,7 +7032,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, // transaction, as these don't have a corresponding output // within the commitment transaction. if HtlcIsDust( - chanType, htlc.Incoming, ourCommit, feePerKw, + chanType, htlc.Incoming, whoseCommit, feePerKw, htlc.Amt.ToSatoshis(), dustLimit, ) { @@ -7014,7 +7047,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, ihr, err := newIncomingHtlcResolution( signer, localChanCfg, commitTx, &htlc, keyRing, feePerKw, uint32(csvDelay), leaseExpiry, - ourCommit, isCommitFromInitiator, chanType, + whoseCommit, isCommitFromInitiator, chanType, ) if err != nil { return nil, fmt.Errorf("incoming resolution "+ @@ -7027,7 +7060,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, ohr, err := newOutgoingHtlcResolution( signer, localChanCfg, commitTx, &htlc, keyRing, - feePerKw, uint32(csvDelay), leaseExpiry, ourCommit, + feePerKw, uint32(csvDelay), leaseExpiry, whoseCommit, isCommitFromInitiator, chanType, ) if err != nil { @@ -7163,7 +7196,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, } commitPoint := input.ComputeCommitmentPoint(revocation[:]) keyRing := DeriveCommitmentKeys( - commitPoint, true, chanState.ChanType, + commitPoint, lntypes.Local, chanState.ChanType, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) @@ -7261,8 +7294,8 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, // use what we have in our latest state when extracting resolutions. localCommit := chanState.LocalCommitment htlcResolutions, err := extractHtlcResolutions( - chainfee.SatPerKWeight(localCommit.FeePerKw), true, signer, - localCommit.Htlcs, keyRing, &chanState.LocalChanCfg, + chainfee.SatPerKWeight(localCommit.FeePerKw), lntypes.Local, + signer, localCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitTx, chanState.ChanType, chanState.IsInitiator, leaseExpiry, ) @@ -7271,7 +7304,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, } anchorResolution, err := NewAnchorResolution( - chanState, commitTx, keyRing, true, + chanState, commitTx, keyRing, lntypes.Local, ) if err != nil { return nil, fmt.Errorf("unable to gen anchor "+ @@ -7561,12 +7594,12 @@ func (lc *LightningChannel) NewAnchorResolutions() (*AnchorResolutions, } localCommitPoint := input.ComputeCommitmentPoint(revocation[:]) localKeyRing := DeriveCommitmentKeys( - localCommitPoint, true, lc.channelState.ChanType, + localCommitPoint, lntypes.Local, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) localRes, err := NewAnchorResolution( lc.channelState, lc.channelState.LocalCommitment.CommitTx, - localKeyRing, true, + localKeyRing, lntypes.Local, ) if err != nil { return nil, err @@ -7575,13 +7608,13 @@ func (lc *LightningChannel) NewAnchorResolutions() (*AnchorResolutions, // Add anchor for remote commitment tx, if any. remoteKeyRing := DeriveCommitmentKeys( - lc.channelState.RemoteCurrentRevocation, false, + lc.channelState.RemoteCurrentRevocation, lntypes.Remote, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) remoteRes, err := NewAnchorResolution( lc.channelState, lc.channelState.RemoteCommitment.CommitTx, - remoteKeyRing, false, + remoteKeyRing, lntypes.Remote, ) if err != nil { return nil, err @@ -7596,14 +7629,14 @@ func (lc *LightningChannel) NewAnchorResolutions() (*AnchorResolutions, if remotePendingCommit != nil { pendingRemoteKeyRing := DeriveCommitmentKeys( - lc.channelState.RemoteNextRevocation, false, + lc.channelState.RemoteNextRevocation, lntypes.Remote, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) remotePendingRes, err := NewAnchorResolution( lc.channelState, remotePendingCommit.Commitment.CommitTx, - pendingRemoteKeyRing, false, + pendingRemoteKeyRing, lntypes.Remote, ) if err != nil { return nil, err @@ -7618,7 +7651,7 @@ func (lc *LightningChannel) NewAnchorResolutions() (*AnchorResolutions, // local anchor. func NewAnchorResolution(chanState *channeldb.OpenChannel, commitTx *wire.MsgTx, keyRing *CommitmentKeyRing, - isLocalCommit bool) (*AnchorResolution, error) { + whoseCommit lntypes.ChannelParty) (*AnchorResolution, error) { // Return nil resolution if the channel has no anchors. if !chanState.ChanType.HasAnchors() { @@ -7636,7 +7669,7 @@ func NewAnchorResolution(chanState *channeldb.OpenChannel, if err != nil { return nil, err } - if chanState.ChanType.IsTaproot() && !isLocalCommit { + if chanState.ChanType.IsTaproot() && whoseCommit.IsRemote() { //nolint:ineffassign localAnchor, remoteAnchor = remoteAnchor, localAnchor } @@ -7690,7 +7723,7 @@ func NewAnchorResolution(chanState *channeldb.OpenChannel, // For anchor outputs with taproot channels, the key desc is // also different: we'll just re-use our local delay base point // (which becomes our to local output). - if isLocalCommit { + if whoseCommit.IsLocal() { // In addition to the sign method, we'll also need to // ensure that the single tweak is set, as with the // current formulation, we'll need to use two levels of @@ -7777,12 +7810,12 @@ func (lc *LightningChannel) availableBalance( // add updates concurrently, causing our balance to go down if we're // the initiator, but this is a problem on the protocol level. ourLocalCommitBalance, commitWeight := lc.availableCommitmentBalance( - htlcView, false, buffer, + htlcView, lntypes.Local, buffer, ) // Do the same calculation from the remote commitment point of view. ourRemoteCommitBalance, _ := lc.availableCommitmentBalance( - htlcView, true, buffer, + htlcView, lntypes.Remote, buffer, ) // Return which ever balance is lowest. @@ -7800,15 +7833,16 @@ func (lc *LightningChannel) availableBalance( // commitment, increasing the commitment fee we must pay as an initiator, // eating into our balance. It will make sure we won't violate the channel // reserve constraints for this amount. -func (lc *LightningChannel) availableCommitmentBalance( - view *htlcView, remoteChain bool, - buffer BufferType) (lnwire.MilliSatoshi, lntypes.WeightUnit) { +func (lc *LightningChannel) availableCommitmentBalance(view *htlcView, + whoseCommitChain lntypes.ChannelParty, buffer BufferType) ( + lnwire.MilliSatoshi, lntypes.WeightUnit) { // Compute the current balances for this commitment. This will take // into account HTLCs to determine the commit weight, which the // initiator must pay the fee for. ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView( - view, remoteChain, false, fn.None[chainfee.SatPerKWeight](), + view, whoseCommitChain, false, + fn.None[chainfee.SatPerKWeight](), ) if err != nil { lc.log.Errorf("Unable to fetch available balance: %v", err) @@ -7894,7 +7928,7 @@ func (lc *LightningChannel) availableCommitmentBalance( // If we are looking at the remote commitment, we must use the remote // dust limit and the fee for adding an HTLC success transaction. - if remoteChain { + if whoseCommitChain.IsRemote() { dustlimit = lnwire.NewMSatFromSatoshis( lc.channelState.RemoteChanCfg.DustLimit, ) @@ -8031,7 +8065,7 @@ func (lc *LightningChannel) CommitFeeTotalAt( // Compute the local commitment's weight. _, _, localWeight, _, err := lc.computeView( - localHtlcView, false, false, dryRunFee, + localHtlcView, lntypes.Local, false, dryRunFee, ) if err != nil { return 0, 0, err @@ -8045,7 +8079,7 @@ func (lc *LightningChannel) CommitFeeTotalAt( // Compute the remote commitment's weight. _, _, remoteWeight, _, err := lc.computeView( - remoteHtlcView, true, false, dryRunFee, + remoteHtlcView, lntypes.Remote, false, dryRunFee, ) if err != nil { return 0, 0, err @@ -8455,12 +8489,12 @@ func (lc *LightningChannel) MarkBorked() error { // for it to confirm before taking any further action. It takes a boolean which // indicates whether we initiated the close. func (lc *LightningChannel) MarkCommitmentBroadcasted(tx *wire.MsgTx, - locallyInitiated bool) error { + closer lntypes.ChannelParty) error { lc.Lock() defer lc.Unlock() - return lc.channelState.MarkCommitmentBroadcasted(tx, locallyInitiated) + return lc.channelState.MarkCommitmentBroadcasted(tx, closer) } // MarkCoopBroadcasted marks the channel as a cooperative close transaction has @@ -8468,12 +8502,12 @@ func (lc *LightningChannel) MarkCommitmentBroadcasted(tx *wire.MsgTx, // taking any further action. It takes a locally initiated bool which is true // if we initiated the cooperative close. func (lc *LightningChannel) MarkCoopBroadcasted(tx *wire.MsgTx, - localInitiated bool) error { + closer lntypes.ChannelParty) error { lc.Lock() defer lc.Unlock() - return lc.channelState.MarkCoopBroadcasted(tx, localInitiated) + return lc.channelState.MarkCoopBroadcasted(tx, closer) } // MarkShutdownSent persists the given ShutdownInfo. The existence of the diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 185bdf87a6..330d8d130e 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -5196,7 +5196,7 @@ func TestChanCommitWeightDustHtlcs(t *testing.T) { lc.localUpdateLog.logIndex) _, w := lc.availableCommitmentBalance( - htlcView, true, FeeBuffer, + htlcView, lntypes.Remote, FeeBuffer, ) return w @@ -7985,11 +7985,11 @@ func TestChannelFeeRateFloor(t *testing.T) { // TestFetchParent tests lookup of an entry's parent in the appropriate log. func TestFetchParent(t *testing.T) { tests := []struct { - name string - remoteChain bool - remoteLog bool - localEntries []*PaymentDescriptor - remoteEntries []*PaymentDescriptor + name string + whoseCommitChain lntypes.ChannelParty + whoseUpdateLog lntypes.ChannelParty + localEntries []*PaymentDescriptor + remoteEntries []*PaymentDescriptor // parentIndex is the parent index of the entry that we will // lookup with fetch parent. @@ -8003,22 +8003,22 @@ func TestFetchParent(t *testing.T) { expectedIndex uint64 }{ { - name: "not found in remote log", - localEntries: nil, - remoteEntries: nil, - remoteChain: true, - remoteLog: true, - parentIndex: 0, - expectErr: true, + name: "not found in remote log", + localEntries: nil, + remoteEntries: nil, + whoseCommitChain: lntypes.Remote, + whoseUpdateLog: lntypes.Remote, + parentIndex: 0, + expectErr: true, }, { - name: "not found in local log", - localEntries: nil, - remoteEntries: nil, - remoteChain: false, - remoteLog: false, - parentIndex: 0, - expectErr: true, + name: "not found in local log", + localEntries: nil, + remoteEntries: nil, + whoseCommitChain: lntypes.Local, + whoseUpdateLog: lntypes.Local, + parentIndex: 0, + expectErr: true, }, { name: "remote log + chain, remote add height 0", @@ -8038,10 +8038,10 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 0, }, }, - remoteChain: true, - remoteLog: true, - parentIndex: 1, - expectErr: true, + whoseCommitChain: lntypes.Remote, + whoseUpdateLog: lntypes.Remote, + parentIndex: 1, + expectErr: true, }, { name: "remote log, local chain, local add height 0", @@ -8060,11 +8060,11 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 100, }, }, - localEntries: nil, - remoteChain: false, - remoteLog: true, - parentIndex: 1, - expectErr: true, + localEntries: nil, + whoseCommitChain: lntypes.Local, + whoseUpdateLog: lntypes.Remote, + parentIndex: 1, + expectErr: true, }, { name: "local log + chain, local add height 0", @@ -8083,11 +8083,11 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 100, }, }, - remoteEntries: nil, - remoteChain: false, - remoteLog: false, - parentIndex: 1, - expectErr: true, + remoteEntries: nil, + whoseCommitChain: lntypes.Local, + whoseUpdateLog: lntypes.Local, + parentIndex: 1, + expectErr: true, }, { @@ -8107,11 +8107,11 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 0, }, }, - remoteEntries: nil, - remoteChain: true, - remoteLog: false, - parentIndex: 1, - expectErr: true, + remoteEntries: nil, + whoseCommitChain: lntypes.Remote, + whoseUpdateLog: lntypes.Local, + parentIndex: 1, + expectErr: true, }, { name: "remote log found", @@ -8131,11 +8131,11 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 100, }, }, - remoteChain: true, - remoteLog: true, - parentIndex: 1, - expectErr: false, - expectedIndex: 2, + whoseCommitChain: lntypes.Remote, + whoseUpdateLog: lntypes.Remote, + parentIndex: 1, + expectErr: false, + expectedIndex: 2, }, { name: "local log found", @@ -8154,12 +8154,12 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 100, }, }, - remoteEntries: nil, - remoteChain: false, - remoteLog: false, - parentIndex: 1, - expectErr: false, - expectedIndex: 2, + remoteEntries: nil, + whoseCommitChain: lntypes.Local, + whoseUpdateLog: lntypes.Local, + parentIndex: 1, + expectErr: false, + expectedIndex: 2, }, } @@ -8186,8 +8186,8 @@ func TestFetchParent(t *testing.T) { &PaymentDescriptor{ ParentIndex: test.parentIndex, }, - test.remoteChain, - test.remoteLog, + test.whoseCommitChain, + test.whoseUpdateLog, ) gotErr := err != nil if test.expectErr != gotErr { @@ -8245,11 +8245,11 @@ func TestEvaluateView(t *testing.T) { ) tests := []struct { - name string - ourHtlcs []*PaymentDescriptor - theirHtlcs []*PaymentDescriptor - remoteChain bool - mutateState bool + name string + ourHtlcs []*PaymentDescriptor + theirHtlcs []*PaymentDescriptor + whoseCommitChain lntypes.ChannelParty + mutateState bool // ourExpectedHtlcs is the set of our htlcs that we expect in // the htlc view once it has been evaluated. We just store @@ -8276,9 +8276,9 @@ func TestEvaluateView(t *testing.T) { expectSent lnwire.MilliSatoshi }{ { - name: "our fee update is applied", - remoteChain: false, - mutateState: false, + name: "our fee update is applied", + whoseCommitChain: lntypes.Local, + mutateState: false, ourHtlcs: []*PaymentDescriptor{ { Amount: ourFeeUpdateAmt, @@ -8293,10 +8293,10 @@ func TestEvaluateView(t *testing.T) { expectSent: 0, }, { - name: "their fee update is applied", - remoteChain: false, - mutateState: false, - ourHtlcs: []*PaymentDescriptor{}, + name: "their fee update is applied", + whoseCommitChain: lntypes.Local, + mutateState: false, + ourHtlcs: []*PaymentDescriptor{}, theirHtlcs: []*PaymentDescriptor{ { Amount: theirFeeUpdateAmt, @@ -8311,9 +8311,9 @@ func TestEvaluateView(t *testing.T) { }, { // We expect unresolved htlcs to to remain in the view. - name: "htlcs adds without settles", - remoteChain: false, - mutateState: false, + name: "htlcs adds without settles", + whoseCommitChain: lntypes.Local, + mutateState: false, ourHtlcs: []*PaymentDescriptor{ { HtlcIndex: 0, @@ -8345,9 +8345,9 @@ func TestEvaluateView(t *testing.T) { expectSent: 0, }, { - name: "our htlc settled, state mutated", - remoteChain: false, - mutateState: true, + name: "our htlc settled, state mutated", + whoseCommitChain: lntypes.Local, + mutateState: true, ourHtlcs: []*PaymentDescriptor{ { HtlcIndex: 0, @@ -8380,9 +8380,9 @@ func TestEvaluateView(t *testing.T) { expectSent: htlcAddAmount, }, { - name: "our htlc settled, state not mutated", - remoteChain: false, - mutateState: false, + name: "our htlc settled, state not mutated", + whoseCommitChain: lntypes.Local, + mutateState: false, ourHtlcs: []*PaymentDescriptor{ { HtlcIndex: 0, @@ -8415,9 +8415,9 @@ func TestEvaluateView(t *testing.T) { expectSent: 0, }, { - name: "their htlc settled, state mutated", - remoteChain: false, - mutateState: true, + name: "their htlc settled, state mutated", + whoseCommitChain: lntypes.Local, + mutateState: true, ourHtlcs: []*PaymentDescriptor{ { HtlcIndex: 0, @@ -8458,9 +8458,10 @@ func TestEvaluateView(t *testing.T) { expectSent: 0, }, { - name: "their htlc settled, state not mutated", - remoteChain: false, - mutateState: false, + name: "their htlc settled, state not mutated", + + whoseCommitChain: lntypes.Local, + mutateState: false, ourHtlcs: []*PaymentDescriptor{ { HtlcIndex: 0, @@ -8543,7 +8544,7 @@ func TestEvaluateView(t *testing.T) { // Evaluate the htlc view, mutate as test expects. result, err := lc.evaluateHTLCView( view, &ourBalance, &theirBalance, nextHeight, - test.remoteChain, test.mutateState, + test.whoseCommitChain, test.mutateState, ) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -8631,12 +8632,12 @@ func TestProcessFeeUpdate(t *testing.T) { ) tests := []struct { - name string - startHeights heights - expectedHeights heights - remoteChain bool - mutate bool - expectedFee chainfee.SatPerKWeight + name string + startHeights heights + expectedHeights heights + whoseCommitChain lntypes.ChannelParty + mutate bool + expectedFee chainfee.SatPerKWeight }{ { // Looking at local chain, local add is non-zero so @@ -8654,9 +8655,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: 0, remoteRemove: height, }, - remoteChain: false, - mutate: false, - expectedFee: feePerKw, + whoseCommitChain: lntypes.Local, + mutate: false, + expectedFee: feePerKw, }, { // Looking at local chain, local add is zero so the @@ -8675,9 +8676,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: height, remoteRemove: 0, }, - remoteChain: false, - mutate: false, - expectedFee: ourFeeUpdatePerSat, + whoseCommitChain: lntypes.Local, + mutate: false, + expectedFee: ourFeeUpdatePerSat, }, { // Looking at remote chain, the remote add height is @@ -8696,9 +8697,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: 0, remoteRemove: 0, }, - remoteChain: true, - mutate: false, - expectedFee: ourFeeUpdatePerSat, + whoseCommitChain: lntypes.Remote, + mutate: false, + expectedFee: ourFeeUpdatePerSat, }, { // Looking at remote chain, the remote add height is @@ -8717,9 +8718,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: height, remoteRemove: 0, }, - remoteChain: true, - mutate: false, - expectedFee: feePerKw, + whoseCommitChain: lntypes.Remote, + mutate: false, + expectedFee: feePerKw, }, { // Local add height is non-zero, so the update has @@ -8738,9 +8739,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: 0, remoteRemove: height, }, - remoteChain: false, - mutate: true, - expectedFee: feePerKw, + whoseCommitChain: lntypes.Local, + mutate: true, + expectedFee: feePerKw, }, { // Local add is zero and we are looking at our local @@ -8760,9 +8761,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: 0, remoteRemove: 0, }, - remoteChain: false, - mutate: true, - expectedFee: ourFeeUpdatePerSat, + whoseCommitChain: lntypes.Local, + mutate: true, + expectedFee: ourFeeUpdatePerSat, }, } @@ -8786,7 +8787,7 @@ func TestProcessFeeUpdate(t *testing.T) { feePerKw: chainfee.SatPerKWeight(feePerKw), } processFeeUpdate( - update, nextHeight, test.remoteChain, + update, nextHeight, test.whoseCommitChain, test.mutate, view, ) @@ -8841,7 +8842,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { tests := []struct { name string startHeights heights - remoteChain bool + whoseCommitChain lntypes.ChannelParty isIncoming bool mutateState bool ourExpectedBalance lnwire.MilliSatoshi @@ -8857,7 +8858,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance, @@ -8878,7 +8879,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: false, + whoseCommitChain: lntypes.Local, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance, @@ -8899,7 +8900,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: false, + whoseCommitChain: lntypes.Local, isIncoming: true, mutateState: false, ourExpectedBalance: startBalance, @@ -8920,7 +8921,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: false, + whoseCommitChain: lntypes.Local, isIncoming: true, mutateState: true, ourExpectedBalance: startBalance, @@ -8942,7 +8943,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance - updateAmount, @@ -8963,7 +8964,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: true, ourExpectedBalance: startBalance - updateAmount, @@ -8984,7 +8985,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: removeHeight, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance, @@ -9005,7 +9006,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: removeHeight, remoteRemove: 0, }, - remoteChain: false, + whoseCommitChain: lntypes.Local, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance, @@ -9028,7 +9029,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: true, mutateState: false, ourExpectedBalance: startBalance + updateAmount, @@ -9051,7 +9052,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance, @@ -9074,7 +9075,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: true, mutateState: false, ourExpectedBalance: startBalance, @@ -9097,7 +9098,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance + updateAmount, @@ -9122,7 +9123,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: false, + whoseCommitChain: lntypes.Local, isIncoming: true, mutateState: true, ourExpectedBalance: startBalance + updateAmount, @@ -9147,7 +9148,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: true, mutateState: true, ourExpectedBalance: startBalance + updateAmount, @@ -9196,7 +9197,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { process( update, &ourBalance, &theirBalance, nextHeight, - test.remoteChain, test.isIncoming, + test.whoseCommitChain, test.isIncoming, test.mutateState, ) @@ -9752,11 +9753,11 @@ func testGetDustSum(t *testing.T, chantype channeldb.ChannelType) { expRemote lnwire.MilliSatoshi) { localDustSum := c.GetDustSum( - false, fn.None[chainfee.SatPerKWeight](), + lntypes.Local, fn.None[chainfee.SatPerKWeight](), ) require.Equal(t, expLocal, localDustSum) remoteDustSum := c.GetDustSum( - true, fn.None[chainfee.SatPerKWeight](), + lntypes.Remote, fn.None[chainfee.SatPerKWeight](), ) require.Equal(t, expRemote, remoteDustSum) } @@ -9910,8 +9911,9 @@ func deriveDummyRetributionParams(chanState *channeldb.OpenChannel) (uint32, config := chanState.RemoteChanCfg commitHash := chanState.RemoteCommitment.CommitTx.TxHash() keyRing := DeriveCommitmentKeys( - config.RevocationBasePoint.PubKey, false, chanState.ChanType, - &chanState.LocalChanCfg, &chanState.RemoteChanCfg, + config.RevocationBasePoint.PubKey, lntypes.Remote, + chanState.ChanType, &chanState.LocalChanCfg, + &chanState.RemoteChanCfg, ) leaseExpiry := chanState.ThawHeight return leaseExpiry, keyRing, commitHash @@ -10378,7 +10380,7 @@ func TestExtractPayDescs(t *testing.T) { // NOTE: we use nil commitment key rings to avoid checking the htlc // scripts(`genHtlcScript`) as it should be tested independently. incomingPDs, outgoingPDs, err := lnChan.extractPayDescs( - 0, 0, htlcs, nil, nil, true, + 0, 0, htlcs, nil, nil, lntypes.Local, ) require.NoError(t, err) diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 96af8d7cf8..2cf58f494e 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -103,7 +103,7 @@ type CommitmentKeyRing struct { // of channel, and whether the commitment transaction is ours or the remote // peer's. func DeriveCommitmentKeys(commitPoint *btcec.PublicKey, - isOurCommit bool, chanType channeldb.ChannelType, + whoseCommit lntypes.ChannelParty, chanType channeldb.ChannelType, localChanCfg, remoteChanCfg *channeldb.ChannelConfig) *CommitmentKeyRing { tweaklessCommit := chanType.IsTweakless() @@ -111,7 +111,7 @@ func DeriveCommitmentKeys(commitPoint *btcec.PublicKey, // Depending on if this is our commit or not, we'll choose the correct // base point. localBasePoint := localChanCfg.PaymentBasePoint - if isOurCommit { + if whoseCommit.IsLocal() { localBasePoint = localChanCfg.DelayBasePoint } @@ -144,7 +144,7 @@ func DeriveCommitmentKeys(commitPoint *btcec.PublicKey, toRemoteBasePoint *btcec.PublicKey revocationBasePoint *btcec.PublicKey ) - if isOurCommit { + if whoseCommit.IsLocal() { toLocalBasePoint = localChanCfg.DelayBasePoint.PubKey toRemoteBasePoint = remoteChanCfg.PaymentBasePoint.PubKey revocationBasePoint = remoteChanCfg.RevocationBasePoint.PubKey @@ -169,7 +169,7 @@ func DeriveCommitmentKeys(commitPoint *btcec.PublicKey, // If this is not our commitment, the above ToRemoteKey will be // ours, and we blank out the local commitment tweak to // indicate that the key should not be tweaked when signing. - if !isOurCommit { + if whoseCommit.IsRemote() { keyRing.LocalCommitKeyTweak = nil } } else { @@ -686,20 +686,20 @@ type unsignedCommitmentTx struct { // passed in balances should be balances *before* subtracting any commitment // fees, but after anchor outputs. func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, - theirBalance lnwire.MilliSatoshi, isOurs bool, + theirBalance lnwire.MilliSatoshi, whoseCommit lntypes.ChannelParty, feePerKw chainfee.SatPerKWeight, height uint64, filteredHTLCView *htlcView, keyRing *CommitmentKeyRing) (*unsignedCommitmentTx, error) { dustLimit := cb.chanState.LocalChanCfg.DustLimit - if !isOurs { + if whoseCommit.IsRemote() { dustLimit = cb.chanState.RemoteChanCfg.DustLimit } numHTLCs := int64(0) for _, htlc := range filteredHTLCView.ourUpdates { if HtlcIsDust( - cb.chanState.ChanType, false, isOurs, feePerKw, + cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -710,7 +710,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } for _, htlc := range filteredHTLCView.theirUpdates { if HtlcIsDust( - cb.chanState.ChanType, true, isOurs, feePerKw, + cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -763,7 +763,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, if cb.chanState.ChanType.HasLeaseExpiration() { leaseExpiry = cb.chanState.ThawHeight } - if isOurs { + if whoseCommit.IsLocal() { commitTx, err = CreateCommitTx( cb.chanState.ChanType, fundingTxIn(cb.chanState), keyRing, &cb.chanState.LocalChanCfg, &cb.chanState.RemoteChanCfg, @@ -794,7 +794,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, cltvs := make([]uint32, len(commitTx.TxOut)) for _, htlc := range filteredHTLCView.ourUpdates { if HtlcIsDust( - cb.chanState.ChanType, false, isOurs, feePerKw, + cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -802,7 +802,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } err := addHTLC( - commitTx, isOurs, false, htlc, keyRing, + commitTx, whoseCommit, false, htlc, keyRing, cb.chanState.ChanType, ) if err != nil { @@ -812,7 +812,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } for _, htlc := range filteredHTLCView.theirUpdates { if HtlcIsDust( - cb.chanState.ChanType, true, isOurs, feePerKw, + cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -820,7 +820,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } err := addHTLC( - commitTx, isOurs, true, htlc, keyRing, + commitTx, whoseCommit, true, htlc, keyRing, cb.chanState.ChanType, ) if err != nil { @@ -1003,8 +1003,9 @@ func CoopCloseBalance(chanType channeldb.ChannelType, isInitiator bool, // genSegwitV0HtlcScript generates the HTLC scripts for a normal segwit v0 // channel. func genSegwitV0HtlcScript(chanType channeldb.ChannelType, - isIncoming, ourCommit bool, timeout uint32, rHash [32]byte, - keyRing *CommitmentKeyRing) (*WitnessScriptDesc, error) { + isIncoming bool, whoseCommit lntypes.ChannelParty, timeout uint32, + rHash [32]byte, keyRing *CommitmentKeyRing, +) (*WitnessScriptDesc, error) { var ( witnessScript []byte @@ -1024,7 +1025,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // The HTLC is paying to us, and being applied to our commitment // transaction. So we need to use the receiver's version of the HTLC // script. - case isIncoming && ourCommit: + case isIncoming && whoseCommit.IsLocal(): witnessScript, err = input.ReceiverHTLCScript( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, @@ -1033,7 +1034,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // We're being paid via an HTLC by the remote party, and the HTLC is // being added to their commitment transaction, so we use the sender's // version of the HTLC script. - case isIncoming && !ourCommit: + case isIncoming && whoseCommit.IsRemote(): witnessScript, err = input.SenderHTLCScript( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, @@ -1042,7 +1043,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // We're sending an HTLC which is being added to our commitment // transaction. Therefore, we need to use the sender's version of the // HTLC script. - case !isIncoming && ourCommit: + case !isIncoming && whoseCommit.IsLocal(): witnessScript, err = input.SenderHTLCScript( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, @@ -1051,7 +1052,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // Finally, we're paying the remote party via an HTLC, which is being // added to their commitment transaction. Therefore, we use the // receiver's version of the HTLC script. - case !isIncoming && !ourCommit: + case !isIncoming && whoseCommit.IsRemote(): witnessScript, err = input.ReceiverHTLCScript( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, @@ -1076,9 +1077,9 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // genTaprootHtlcScript generates the HTLC scripts for a taproot+musig2 // channel. -func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, - rHash [32]byte, - keyRing *CommitmentKeyRing) (*input.HtlcScriptTree, error) { +func genTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, + timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing, +) (*input.HtlcScriptTree, error) { var ( htlcScriptTree *input.HtlcScriptTree @@ -1092,37 +1093,37 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, // The HTLC is paying to us, and being applied to our commitment // transaction. So we need to use the receiver's version of HTLC the // script. - case isIncoming && ourCommit: + case isIncoming && whoseCommit.IsLocal(): htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], whoseCommit, ) // We're being paid via an HTLC by the remote party, and the HTLC is // being added to their commitment transaction, so we use the sender's // version of the HTLC script. - case isIncoming && !ourCommit: + case isIncoming && whoseCommit.IsRemote(): htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], whoseCommit, ) // We're sending an HTLC which is being added to our commitment // transaction. Therefore, we need to use the sender's version of the // HTLC script. - case !isIncoming && ourCommit: + case !isIncoming && whoseCommit.IsLocal(): htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], whoseCommit, ) // Finally, we're paying the remote party via an HTLC, which is being // added to their commitment transaction. Therefore, we use the // receiver's version of the HTLC script. - case !isIncoming && !ourCommit: + case !isIncoming && whoseCommit.IsRemote(): htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], whoseCommit, ) } @@ -1135,19 +1136,20 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, // multiplexer for the various spending paths is returned. The script path that // we need to sign for the remote party (2nd level HTLCs) is also returned // along side the multiplexer. -func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, - timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing, +func genHtlcScript(chanType channeldb.ChannelType, isIncoming bool, + whoseCommit lntypes.ChannelParty, timeout uint32, rHash [32]byte, + keyRing *CommitmentKeyRing, ) (input.ScriptDescriptor, error) { if !chanType.IsTaproot() { return genSegwitV0HtlcScript( - chanType, isIncoming, ourCommit, timeout, rHash, + chanType, isIncoming, whoseCommit, timeout, rHash, keyRing, ) } return genTaprootHtlcScript( - isIncoming, ourCommit, timeout, rHash, keyRing, + isIncoming, whoseCommit, timeout, rHash, keyRing, ) } @@ -1158,7 +1160,7 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, // locate the added HTLC on the commitment transaction from the // PaymentDescriptor that generated it, the generated script is stored within // the descriptor itself. -func addHTLC(commitTx *wire.MsgTx, ourCommit bool, +func addHTLC(commitTx *wire.MsgTx, whoseCommit lntypes.ChannelParty, isIncoming bool, paymentDesc *PaymentDescriptor, keyRing *CommitmentKeyRing, chanType channeldb.ChannelType) error { @@ -1166,7 +1168,7 @@ func addHTLC(commitTx *wire.MsgTx, ourCommit bool, rHash := paymentDesc.RHash scriptInfo, err := genHtlcScript( - chanType, isIncoming, ourCommit, timeout, rHash, keyRing, + chanType, isIncoming, whoseCommit, timeout, rHash, keyRing, ) if err != nil { return err @@ -1180,7 +1182,7 @@ func addHTLC(commitTx *wire.MsgTx, ourCommit bool, // Store the pkScript of this particular PaymentDescriptor so we can // quickly locate it within the commitment transaction later. - if ourCommit { + if whoseCommit.IsLocal() { paymentDesc.ourPkScript = pkScript paymentDesc.ourWitnessScript = scriptInfo.WitnessScriptToSign() @@ -1211,7 +1213,7 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, // With the commitment point generated, we can now derive the king ring // which will be used to generate the output scripts. keyRing := DeriveCommitmentKeys( - commitmentPoint, false, chanState.ChanType, + commitmentPoint, lntypes.Remote, chanState.ChanType, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index cf61606da5..a56bf1c217 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -25,6 +25,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chanfunding" "github.com/lightningnetwork/lnd/lnwallet/chanvalidate" @@ -1475,10 +1476,12 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, leaseExpiry uint32) (*wire.MsgTx, *wire.MsgTx, error) { localCommitmentKeys := DeriveCommitmentKeys( - localCommitPoint, true, chanType, ourChanCfg, theirChanCfg, + localCommitPoint, lntypes.Local, chanType, ourChanCfg, + theirChanCfg, ) remoteCommitmentKeys := DeriveCommitmentKeys( - remoteCommitPoint, false, chanType, ourChanCfg, theirChanCfg, + remoteCommitPoint, lntypes.Remote, chanType, ourChanCfg, + theirChanCfg, ) ourCommitTx, err := CreateCommitTx( diff --git a/peer/brontide.go b/peer/brontide.go index 7a390cfd7c..27438daa6f 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -36,6 +36,7 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnpeer" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -1069,7 +1070,7 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( chanCloser, err := p.createChanCloser( lnChan, info.DeliveryScript.Val, feePerKw, nil, - info.LocalInitiator.Val, + info.Closer(), ) if err != nil { shutdownInfoErr = fmt.Errorf("unable to "+ @@ -2732,7 +2733,7 @@ func (p *Brontide) fetchActiveChanCloser(chanID lnwire.ChannelID) ( } chanCloser, err = p.createChanCloser( - channel, deliveryScript, feePerKw, nil, false, + channel, deliveryScript, feePerKw, nil, lntypes.Remote, ) if err != nil { p.log.Errorf("unable to create chan closer: %v", err) @@ -2969,12 +2970,13 @@ func (p *Brontide) restartCoopClose(lnChan *lnwallet.LightningChannel) ( // Determine whether we or the peer are the initiator of the coop // close attempt by looking at the channel's status. - locallyInitiated := c.HasChanStatus( - channeldb.ChanStatusLocalCloseInitiator, - ) + closingParty := lntypes.Remote + if c.HasChanStatus(channeldb.ChanStatusLocalCloseInitiator) { + closingParty = lntypes.Local + } chanCloser, err := p.createChanCloser( - lnChan, deliveryScript, feePerKw, nil, locallyInitiated, + lnChan, deliveryScript, feePerKw, nil, closingParty, ) if err != nil { p.log.Errorf("unable to create chan closer: %v", err) @@ -3003,7 +3005,7 @@ func (p *Brontide) restartCoopClose(lnChan *lnwallet.LightningChannel) ( func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel, deliveryScript lnwire.DeliveryAddress, fee chainfee.SatPerKWeight, req *htlcswitch.ChanClose, - locallyInitiated bool) (*chancloser.ChanCloser, error) { + closer lntypes.ChannelParty) (*chancloser.ChanCloser, error) { _, startingHeight, err := p.cfg.ChainIO.GetBestBlock() if err != nil { @@ -3039,7 +3041,7 @@ func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel, fee, uint32(startingHeight), req, - locallyInitiated, + closer, ) return chanCloser, nil @@ -3096,7 +3098,8 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) { } chanCloser, err := p.createChanCloser( - channel, deliveryScript, req.TargetFeePerKw, req, true, + channel, deliveryScript, req.TargetFeePerKw, req, + lntypes.Local, ) if err != nil { p.log.Errorf(err.Error())