diff --git a/src/embit/descriptor/arguments.py b/src/embit/descriptor/arguments.py index 0c8d4d5..b99309e 100644 --- a/src/embit/descriptor/arguments.py +++ b/src/embit/descriptor/arguments.py @@ -14,7 +14,10 @@ def __init__(self, fingerprint: bytes, derivation: list): @classmethod def from_string(cls, s: str): arr = s.split("/") - mfp = unhexlify(arr[0]) + try: + mfp = unhexlify(arr[0]) + except Exception as err: + raise ArgumentError(str(err)) if len(mfp) != 4: raise ArgumentError("Invalid fingerprint length") arr[0] = "m" diff --git a/src/embit/descriptor/descriptor.py b/src/embit/descriptor/descriptor.py index 28ced7a..f9fde02 100644 --- a/src/embit/descriptor/descriptor.py +++ b/src/embit/descriptor/descriptor.py @@ -27,9 +27,25 @@ def __init__( # - raise if taptree is not None, but taproot=False if key is None and miniscript is None and taptree is None: raise DescriptorError("Provide a key, miniscript or taptree") + + if key is not None and not isinstance(key, Key): + raise DescriptorError("'%s' object is not a Key" % key.__class__.__name__) + if miniscript is not None: # will raise if can't verify - miniscript.verify() + try: + miniscript.verify() + except Exception as err: + raise DescriptorError(str(err)) + + # Chek correctness of miniscript + # (https://github.com/bitcoin/bips/blob/master/bip-0379.md) + # the top level B (base) cover the following cases: + # - older(n), after(n); + # - sha256(h), ripemd160(h), hash256(h), hash160(h); + # - and_b(X,Y), or_b(X,Z), or_d(X,Z); + # - thresh(k,X_1,...,X_n), multi(k,key_1,...,key_n), multi_a(k,key_1,...,key_n); + # - c:X, d:X, j:X, n:X; if miniscript.type != "B": raise DescriptorError("Top level miniscript should be 'B'") # check all branches have the same length @@ -44,10 +60,15 @@ def __init__( self.miniscript = miniscript self.wpkh = wpkh self.taproot = taproot + self.taptree = taptree or TapTree() + # make sure all keys are either taproot or not - for k in self.keys: - k.taproot = taproot + try: + for k in self.keys: + k.taproot = taproot + except Exception as err: + raise DescriptorError(str(err)) @property def script_len(self): diff --git a/tests/tests/test_descriptor.py b/tests/tests/test_descriptor.py index de3e4a6..1bc0ed8 100644 --- a/tests/tests/test_descriptor.py +++ b/tests/tests/test_descriptor.py @@ -1,14 +1,62 @@ from unittest import TestCase from binascii import hexlify from embit.descriptor import Descriptor, Key +from embit.descriptor import miniscript from embit.descriptor.arguments import KeyHash, Number from embit.descriptor.miniscript import OPERATORS, WRAPPERS -from embit.descriptor.errors import MiniscriptError +from embit.descriptor.errors import ArgumentError, MiniscriptError from embit.descriptor.checksum import add_checksum, DescriptorError from embit import ec class DescriptorTest(TestCase): + + def test_fail_initialization(self): + """ + Tests the initialization of a Descriptor instance + initialization, where key, miniscript or taptree + arguments None. + """ + cases = [ + (None, None, None, "Provide a key, miniscript or taptree"), + ("key", None, None, "'str' object is not a Key"), + (None, "miniscript", None, "'str' object has no attribute 'verify'"), + (None, None, "taptree", "'str' object has no attribute 'keys'"), + ] + + for case in cases: + with self.assertRaises(DescriptorError) as exc: + Descriptor(key=case[0], miniscript=case[1], taptree=case[2]) + self.assertEqual(str(exc.exception), case[3]) + + def test_fail_wrong_fingerprint_length(self): + """Tests that a descriptor with a wrong fingerprint raises an error""" + cases = [ + # A valid fingerprint for the test key is c1684a69 + # but we use a wrong one here that is c1684a69a (9 characters instead of 8) + ( + "wpkh([c1684a69a/84'/1'/0']tpubDDY2HTrz5YTJGe4dejjxgiiuex6Gfu7Ca21zEkCf7GcWpPhpM172yt9aeJqWg5zD7n6gUFfnFJeyMokc54rCQ9tWLjs9VaZHxV95g6RYjf5/1/*)#2zrpegx5", + "Odd-length string", + ), + # A valid fingerprint for the test key is c1684a69 + # but we added a bytes at the end to make it invalid (a0) + ( + "wpkh([c1684a69a0/84'/1'/0']tpubDDY2HTrz5YTJGe4dejjxgiiuex6Gfu7Ca21zEkCf7GcWpPhpM172yt9aeJqWg5zD7n6gUFfnFJeyMokc54rCQ9tWLjs9VaZHxV95g6RYjf5/1/*)#2zrpegx5", + "Invalid fingerprint length", + ), + # A valid fingerprint for the test key is c1684a69 + # but we use a wrong one here that is c1684axj (xj characters instead of 69) + ( + "wpkh([c1684axj/84'/1'/0']tpubDDY2HTrz5YTJGe4dejjxgiiuex6Gfu7Ca21zEkCf7GcWpPhpM172yt9aeJqWg5zD7n6gUFfnFJeyMokc54rCQ9tWLjs9VaZHxV95g6RYjf5/1/*)#2zrpegx5", + "Non-hexadecimal digit found", + ), + ] + + for case in cases: + with self.assertRaises(ArgumentError) as exc: + Descriptor.from_string(case[0]) + self.assertEqual(str(exc.exception), case[1]) + def test_desc_checksum(self): """Tests descriptor checksums""" desc = ( diff --git a/tests/tests/test_taproot.py b/tests/tests/test_taproot.py index 3e3ad64..86746b4 100644 --- a/tests/tests/test_taproot.py +++ b/tests/tests/test_taproot.py @@ -7,6 +7,7 @@ from embit.psbtview import PSBTView from embit.ec import SchnorrSig, PublicKey from embit.transaction import SIGHASH +from embit.descriptor.errors import ArgumentError from io import BytesIO from binascii import unhexlify @@ -131,6 +132,15 @@ def test_invalid(self): "wpkh(b4ca2da5380d9aeb5ca67e4f18c487ae9b668748517e12b788496f63765e2efa)" ) + def test_fail_taproot_tweak(self): + desc = "wpkh([abcdef12/84h/22h]xpub6F6wWxm8F64iBHNhyaoh3QKCuuMUY5pfPPr1H1WuZXUXeXtZ21qjFN5ykaqnLL1jtPEFB9d94CyZrcYWKVdSiJKQ6mLGEB5sfrGFBpg6wgA/<0;1>/*)" + non_tr_desc = Descriptor.from_string(desc) + key = non_tr_desc.keys[0] + with self.assertRaises(Exception) as exc: + assert not getattr(key, "taproot", False) + key.taproot_tweak() + self.assertEqual(str(exc.exception), "Key is not taproot") + def test_sign_verify(self): unsigned = PSBT.from_string(B64PSBT) signed = PSBT.from_string(B64SIGNED)