Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion bip380/descriptors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)

from .checksum import descsum_create
from .errors import DescriptorParsingError
from .parsing import descriptor_from_str


Expand All @@ -21,7 +22,22 @@ def from_str(desc_str, strict=False):

:param strict: whether to require the presence of a checksum.
"""
return descriptor_from_str(desc_str, strict)
desc = descriptor_from_str(desc_str, strict)

# BIP389 prescribes that no two multipath key expressions in a single descriptor
# have different length.
multipath_len = None
for key in desc.keys:
if key.is_multipath():
m_len = len(key.path.paths)
if multipath_len is None:
multipath_len = m_len
elif multipath_len != m_len:
raise DescriptorParsingError(
f"Descriptor contains multipath key expressions with varying length: '{desc_str}'."
)

return desc

@property
def script_pubkey(self):
Expand Down Expand Up @@ -60,6 +76,41 @@ def satisfy(self, *args, **kwargs):
# To be implemented by derived classes
raise NotImplementedError

def copy(self):
"""Get a copy of this descriptor."""
# FIXME: do something nicer than roundtripping through string ser
return Descriptor.from_str(str(self))

def is_multipath(self):
"""Whether this descriptor contains multipath key expression(s)."""
return any(k.is_multipath() for k in self.keys)

def singlepath_descriptors(self):
"""Get a list of descriptors that only contain keys that don't have multiple
derivation paths.
"""
singlepath_descs = [self.copy()]

# First figure out the number of descriptors there will be
for key in self.keys:
if key.is_multipath():
singlepath_descs += [self.copy() for _ in range(len(key.path.paths) - 1)]
break

# Return early if there was no multipath key expression
if len(singlepath_descs) == 1:
return singlepath_descs

# Then use one path for each
for i, desc in enumerate(singlepath_descs):
for key in desc.keys:
if key.is_multipath():
assert len(key.path.paths) == len(singlepath_descs)
key.path.paths = key.path.paths[i : i + 1]

assert all(not d.is_multipath() for d in singlepath_descs)
return singlepath_descs


# TODO: add methods to give access to all the Miniscript analysis
class WshDescriptor(Descriptor):
Expand Down
165 changes: 124 additions & 41 deletions bip380/key.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from bip32 import BIP32
import copy

from bip32 import BIP32, HARDENED_INDEX
from bip32.utils import coincurve, _deriv_path_str_to_list
from bip380.utils.hashes import hash160
from enum import Enum, auto
Expand Down Expand Up @@ -58,48 +60,83 @@ def is_wildcard(self):
return self in [KeyPathKind.WILDCARD_HARDENED, KeyPathKind.WILDCARD_UNHARDENED]


def parse_index(index_str):
"""Parse a derivation index, as contained in a derivation path."""
assert isinstance(index_str, str)

try:
# if HARDENED
if index_str[-1:] in ["'", "h", "H"]:
return int(index_str[:-1]) + HARDENED_INDEX
else:
return int(index_str)
except ValueError as e:
raise DescriptorKeyError(f"Invalid derivation index {index_str}: '{e}'")


class DescriptorKeyPath:
"""The derivation path of a key in a descriptor.

See https://github.com/bitcoin/bips/blob/master/bip-0380.mediawiki#key-expressions.
See https://github.com/bitcoin/bips/blob/master/bip-0380.mediawiki#key-expressions
as well as BIP389 for multipath expressions.
"""

def __init__(self, path, kind):
assert isinstance(path, list) and isinstance(kind, KeyPathKind)
def __init__(self, paths, kind):
assert (
isinstance(paths, list)
and isinstance(kind, KeyPathKind)
and len(paths) > 0
and all(isinstance(p, list) for p in paths)
)

self.path = path
self.paths = paths
self.kind = kind

def is_multipath(self):
"""Whether this derivation path actually contains multiple of them."""
return len(self.paths) > 1

def from_str(path_str):
if len(path_str) < 1:
if len(path_str) < 2:
raise DescriptorKeyError(f"Insane key path: '{path_str}'")
if path_str[0] == "/":
if path_str[0] != "/":
raise DescriptorKeyError(f"Insane key path: '{path_str}'")

# Determine whether this key may be derived.
kind = KeyPathKind.FINAL
if path_str[-2:] in ["*'", "*h", "*H"]:
if len(path_str) > 2 and path_str[-3:] in ["/*'", "/*h", "/*H"]:
kind = KeyPathKind.WILDCARD_HARDENED
path_str = path_str[:-2]
elif path_str[-1] == "*":
path_str = path_str[:-3]
elif len(path_str) > 1 and path_str[-2:] == "/*":
kind = KeyPathKind.WILDCARD_UNHARDENED
path_str = path_str[:-1]

# We use an internal helper from python-bip32 to parse the path.
# The helper operates on "m/10h/11/12'/13", so give it a "m/".
if len(path_str) > 1:
dummy = "m/"
# If we just trimmed the wildcard part, time the trailing '/' too.
if kind.is_wildcard():
path_str = path_str[:-1]
try:
path = _deriv_path_str_to_list(dummy + path_str)
except ValueError:
raise DescriptorKeyError(f"Insane path in key path: '{path_str}'")
else:
path = []
path_str = path_str[:-2]

return DescriptorKeyPath(path, kind)
paths = [[]]
if len(path_str) == 0:
return DescriptorKeyPath(paths, kind)

for index in path_str[1:].split("/"):
# If this is a multipath expression, of the form '<X;X>'
if (
index.startswith("<")
and index.endswith(">")
and ";" in index
and len(index) >= 5
):
# Can't have more than one multipath expression
if len(paths) > 1:
raise DescriptorKeyError(
f"May only have a single multipath step in derivation path: '{path_str}'"
)
indexes = index[1:-1].split(";")
paths = [copy.copy(paths[0]) for _ in indexes]
for i, der_index in enumerate(indexes):
paths[i].append(parse_index(der_index))
else:
# This is a "single index" expression.
for path in paths:
path.append(parse_index(index))
return DescriptorKeyPath(paths, kind)


class DescriptorKey:
Expand Down Expand Up @@ -144,7 +181,7 @@ def __init__(self, key):
splitted_key = key.split("/", maxsplit=1)
if len(splitted_key) == 2:
key, path = splitted_key
self.path = DescriptorKeyPath.from_str(path)
self.path = DescriptorKeyPath.from_str("/" + path)

try:
self.key = BIP32.from_xpub(key)
Expand All @@ -159,17 +196,33 @@ def __init__(self, key):
def __repr__(self):
key = ""

def ser_path(key, path):
for i in path:
if i < 2**31:
key += f"/{i}"
def ser_index(key, der_index):
# If this a hardened step, deduce the threshold and mark it.
if der_index < HARDENED_INDEX:
return str(der_index)
else:
return f"{der_index - 2**31}'"

def ser_paths(key, paths):
assert len(paths) > 0

for i, der_index in enumerate(paths[0]):
# If this is a multipath expression, write the multi-index step accordingly
if len(paths) > 1 and paths[1][i] != der_index:
key += "/<"
for j, path in enumerate(paths):
key += ser_index(key, path[i])
if j < len(paths) - 1:
key += ";"
key += ">"
else:
key += f"/{i - 2**31}'"
key += "/" + ser_index(key, der_index)

return key

if self.origin is not None:
key += f"[{self.origin.fingerprint.hex()}"
key = ser_path(key, self.origin.path)
key = ser_paths(key, [self.origin.path])
key += "]"

if isinstance(self.key, BIP32):
Expand All @@ -179,30 +232,59 @@ def ser_path(key, path):
key += self.key.format().hex()

if self.path is not None:
key = ser_path(key, self.path.path)
if self.path.kind.is_wildcard():
key = ser_paths(key, self.path.paths)
if self.path.kind == KeyPathKind.WILDCARD_UNHARDENED:
key += "/*"
elif self.path.kind == KeyPathKind.WILDCARD_HARDENED:
key += "/*'"

return key

def is_multipath(self):
"""Whether this key contains more than one derivation path."""
return self.path is not None and self.path.is_multipath()

def derivation_path(self):
"""Get the single derivation path for this key.

Will raise if it has multiple, and return None if it doesn't have any.
"""
if self.path is None:
return None
if self.path.is_multipath():
raise DescriptorKeyError(
f"Key has multiple derivation paths: {self.path.paths}"
)
return self.path.paths[0]

def bytes(self):
"""Get this key as raw bytes.

Will raise if this key contains multiple derivation paths.
"""
if isinstance(self.key, coincurve.PublicKey):
return self.key.format()
else:
assert isinstance(self.key, BIP32)
if self.path is None or self.path.path == []:
path = self.derivation_path()
if path is None:
return self.key.pubkey
assert not self.path.kind.is_wildcard() # TODO: real errors
return self.key.get_pubkey_from_path(self.path.path)
return self.key.get_pubkey_from_path(path)

def derive(self, index):
"""Derive the key at the given index.

Will raise if this key contains multiple derivation paths.
A no-op if the key isn't a wildcard. Will start from 2**31 if the key is a "hardened
wildcard".
"""
assert isinstance(index, int)
if self.path is None or self.path.kind == KeyPathKind.FINAL:
if (
self.path is None
or self.path.is_multipath()
or self.path.kind == KeyPathKind.FINAL
):
return
assert isinstance(self.key, BIP32)

Expand All @@ -215,8 +297,9 @@ def derive(self, index):
self.origin = DescriporKeyOrigin(fingerprint, [index])
else:
self.origin.path.append(index)

# This can't fail now.
path = self.derivation_path()
# TODO(bip32): have a way to derive without roundtripping through string ser.
self.key = BIP32.from_xpub(
self.key.get_xpub_from_path(self.path.path + [index])
)
self.key = BIP32.from_xpub(self.key.get_xpub_from_path(path + [index]))
self.path = None
Loading