Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions feature/required.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package feature

import (
"fmt"

"github.com/lightningnetwork/lnd/lnwire"
)

// ErrUnknownRequired signals that a feature vector requires certain features
// that our node is unaware of or does not implement.
type ErrUnknownRequired struct {
unknown []lnwire.FeatureBit
}

// NewErrUnknownRequired initializes an ErrUnknownRequired with the unknown
// feature bits.
func NewErrUnknownRequired(unknown []lnwire.FeatureBit) ErrUnknownRequired {
return ErrUnknownRequired{
unknown: unknown,
}
}

// Error returns a human-readable description of the error.
func (e ErrUnknownRequired) Error() string {
return fmt.Sprintf("feature vector contains unknown required "+
"features: %v", e.unknown)
}

// ValidateRequired returns an error if the feature vector contains a non-zero
// number of unknown, required feature bits.
func ValidateRequired(fv *lnwire.FeatureVector) error {
unknown := fv.UnknownRequiredFeatures()
if len(unknown) > 0 {
return NewErrUnknownRequired(unknown)
}
return nil
}
12 changes: 5 additions & 7 deletions peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2460,7 +2460,7 @@ func (p *peer) handleInitMsg(msg *lnwire.Init) error {
// those presented in the local features fields.
err := msg.Features.Merge(msg.GlobalFeatures)
if err != nil {
return fmt.Errorf("unable to merge legacy global featues: %v",
return fmt.Errorf("unable to merge legacy global features: %v",
err)
}

Expand All @@ -2472,19 +2472,17 @@ func (p *peer) handleInitMsg(msg *lnwire.Init) error {

// Now that we have their features loaded, we'll ensure that they
// didn't set any required bits that we don't know of.
unknownFeatures := p.remoteFeatures.UnknownRequiredFeatures()
if len(unknownFeatures) > 0 {
err := fmt.Errorf("peer set unknown feature bits: %v",
unknownFeatures)
return err
err = feature.ValidateRequired(p.remoteFeatures)
if err != nil {
return fmt.Errorf("invalid remote features: %v", err)
}

// Ensure the remote party's feature vector contains all transistive
// dependencies. We know ours are are correct since they are validated
// during the feature manager's instantiation.
err = feature.ValidateDeps(p.remoteFeatures)
if err != nil {
return fmt.Errorf("peer set invalid feature vector: %v", err)
return fmt.Errorf("invalid remote features: %v", err)
}

// Now that we know we understand their requirements, we'll check to
Expand Down
29 changes: 24 additions & 5 deletions routing/pathfind.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,14 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
}
}

// With the destination's feature vector selected, ensure that all
// transitive depdencies are set.
// Ensure that the destination's features don't include unknown
// required features.
err = feature.ValidateRequired(features)
if err != nil {
return nil, err
}

// Ensure that all transitive dependencies are set.
err = feature.ValidateDeps(features)
if err != nil {
return nil, err
Expand Down Expand Up @@ -752,11 +758,24 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,

// If the node exists and has valid features, use them.
case err == nil:
err := feature.ValidateDeps(targetNode.Features)
if err == nil {
fromFeatures = targetNode.Features
nodeFeatures := targetNode.Features

// Don't route through nodes that contain
// unknown required features.
err = feature.ValidateRequired(nodeFeatures)
if err != nil {
break
}
Comment thread
cfromknecht marked this conversation as resolved.
Outdated

// Don't route through nodes that don't properly
// set all transitive feature dependencies.
err = feature.ValidateDeps(nodeFeatures)
if err != nil {
break
}

fromFeatures = nodeFeatures

// If an error other than the node not existing is hit,
// abort.
case err != channeldb.ErrGraphNodeNotFound:
Expand Down
71 changes: 71 additions & 0 deletions routing/pathfind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ var (
lnwire.PaymentAddrOptional,
), lnwire.Features,
)

unknownRequiredFeatures = lnwire.NewFeatureVector(
lnwire.NewRawFeatureVector(100), lnwire.Features,
)
)

var (
Expand Down Expand Up @@ -1645,6 +1649,73 @@ func TestMissingFeatureDep(t *testing.T) {
}
}

// TestUnknownRequiredFeatures asserts that we fail path finding when the
// destination requires an unknown required feature, and that we skip
// intermediaries that signal unknown required features.
func TestUnknownRequiredFeatures(t *testing.T) {
t.Parallel()

testChannels := []*testChannel{
asymmetricTestChannel("roasbeef", "conner", 100000,
&testChannelPolicy{
Expiry: 144,
FeeRate: 400,
MinHTLC: 1,
MaxHTLC: 100000000,
},
&testChannelPolicy{
Expiry: 144,
FeeRate: 400,
MinHTLC: 1,
MaxHTLC: 100000000,
Features: unknownRequiredFeatures,
}, 0,
),
asymmetricTestChannel("conner", "joost", 100000,
&testChannelPolicy{
Expiry: 144,
FeeRate: 400,
MinHTLC: 1,
MaxHTLC: 100000000,
Features: unknownRequiredFeatures,
},
&testChannelPolicy{
Expiry: 144,
FeeRate: 400,
MinHTLC: 1,
MaxHTLC: 100000000,
}, 0,
),
}

ctx := newPathFindingTestContext(t, testChannels, "roasbeef")
defer ctx.cleanup()

conner := ctx.keyFromAlias("conner")
joost := ctx.keyFromAlias("joost")

// Conner's node in the graph has an unknown required feature (100).
// Pathfinding should fail since we check the destination's features for
// unknown required features before beginning pathfinding.
expErr := feature.NewErrUnknownRequired([]lnwire.FeatureBit{100})
_, err := ctx.findPath(conner, 100)
if !reflect.DeepEqual(err, expErr) {
t.Fatalf("path shouldn't have been found: %v", err)
}

// Now, try to find a route to joost through conner. The destination
// features are valid, but conner's feature vector in the graph still
// requires feature 100. We expect errNoPathFound and not the error
// above since intermediate hops are simply skipped if they have invalid
// feature vectors, leaving no possible route to joost. This asserts
// that we don't try to route _through_ nodes with unknown required
// features.
_, err = ctx.findPath(joost, 100)
if err != errNoPathFound {
t.Fatalf("path shouldn't have been found: %v", err)
}
}

// TestDestPaymentAddr asserts that we properly detect when we can send a
// payment address to a receiver, and also that we fallback to the receiver's
// node announcement if we don't have an invoice features.
Expand Down
29 changes: 2 additions & 27 deletions watchtower/wtwire/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"

"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/feature"
"github.com/lightningnetwork/lnd/lnwire"
)

Expand Down Expand Up @@ -92,12 +93,7 @@ func (msg *Init) CheckRemoteInit(remoteInit *Init,

// Check that the remote peer doesn't have any required connection
// feature bits that we ourselves are unaware of.
unknownConnFeatures := remoteConnFeatures.UnknownRequiredFeatures()
if len(unknownConnFeatures) > 0 {
return NewErrUnknownRequiredFeatures(unknownConnFeatures...)
}

return nil
return feature.ValidateRequired(remoteConnFeatures)
}

// ErrUnknownChainHash signals that the remote Init has a different chain hash
Expand All @@ -116,24 +112,3 @@ func NewErrUnknownChainHash(hash chainhash.Hash) *ErrUnknownChainHash {
func (e *ErrUnknownChainHash) Error() string {
return fmt.Sprintf("remote init has unknown chain hash: %s", e.hash)
}

// ErrUnknownRequiredFeatures signals that the remote Init has required feature
// bits that were unknown to us.
type ErrUnknownRequiredFeatures struct {
unknownFeatures []lnwire.FeatureBit
}

// NewErrUnknownRequiredFeatures creates an ErrUnknownRequiredFeatures using the
// remote Init's required features that were unknown to us.
func NewErrUnknownRequiredFeatures(
unknownFeatures ...lnwire.FeatureBit) *ErrUnknownRequiredFeatures {

return &ErrUnknownRequiredFeatures{unknownFeatures}
}

// Error returns a human-readable error displaying the unknown required feature
// bits.
func (e *ErrUnknownRequiredFeatures) Error() string {
return fmt.Sprintf("remote init has unknown required features: %v",
e.unknownFeatures)
}
5 changes: 3 additions & 2 deletions watchtower/wtwire/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/feature"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
Expand Down Expand Up @@ -60,8 +61,8 @@ var checkRemoteInitTests = []checkRemoteInitTest{
lHash: testnetChainHash,
rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired),
rHash: testnetChainHash,
expErr: wtwire.NewErrUnknownRequiredFeatures(
lnwire.GossipQueriesRequired,
expErr: feature.NewErrUnknownRequired(
[]lnwire.FeatureBit{lnwire.GossipQueriesRequired},
),
},
}
Expand Down
9 changes: 1 addition & 8 deletions zpay32/invoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -956,14 +956,7 @@ func parseFeatures(data []byte) (*lnwire.FeatureVector, error) {
return nil, err
}

fv := lnwire.NewFeatureVector(rawFeatures, lnwire.Features)
unknownFeatures := fv.UnknownRequiredFeatures()
if len(unknownFeatures) > 0 {
return nil, fmt.Errorf("invoice contains unknown required "+
"features: %v", unknownFeatures)
}

return fv, nil
return lnwire.NewFeatureVector(rawFeatures, lnwire.Features), nil
}

// writeTaggedFields writes the non-nil tagged fields of the Invoice to the
Expand Down
7 changes: 3 additions & 4 deletions zpay32/invoice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,8 @@ func TestDecodeEncode(t *testing.T) {
{
// On mainnet, please send $30 coffee beans supporting
// features 9, 15, 99, and 100, using secret 0x11...
encodedInvoice: "lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdeessp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygs9q4psqqqqqqqqqqqqqqqpqqqqu7fz6pjqczdm3jp3qps7xntj2w2mm70e0ckhw3c5xk9p36pvk3sewn7ncaex6uzfq0vtqzy28se6pcwn790vxex7xystzumhg55p6qq9wq7td",
valid: false,
skipEncoding: true,
encodedInvoice: "lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdeessp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygs9q4psqqqqqqqqqqqqqqqpqsqq40wa3khl49yue3zsgm26jrepqr2eghqlx86rttutve3ugd05em86nsefzh4pfurpd9ek9w2vp95zxqnfe2u7ckudyahsa52q66tgzcp6t2dyk",
valid: true,
decodedInvoice: func() *Invoice {
return &Invoice{
Net: &chaincfg.MainNetParams,
Expand Down Expand Up @@ -710,7 +709,7 @@ func TestDecodeEncode(t *testing.T) {
}

if test.valid {
if err := compareInvoices(test.decodedInvoice(), invoice); err != nil {
if err := compareInvoices(decodedInvoice, invoice); err != nil {
t.Errorf("Invoice decoding result %d not as expected: %v", i, err)
return
}
Expand Down