Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 65 additions & 16 deletions construct-stubs/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ from construct.lib import (
ListType,
RebufferedBytesIO,
)
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESGCM, ChaCha20Poly1305
from cryptography.hazmat.primitives.ciphers.modes import Mode
from typing_extensions import Buffer

# unfortunately, there are a few duplications with "typing", e.g. Union and Optional, which is why the t. prefix must be used everywhere
Expand Down Expand Up @@ -67,6 +70,7 @@ class RawCopyError(ConstructError): ...
class RotationError(ConstructError): ...
class ChecksumError(ConstructError): ...
class CancelParsing(ConstructError): ...
class CipherError(ConstructError): ...

# ===============================================================================
# used internally
Expand All @@ -86,6 +90,17 @@ def stream_size(stream: StreamType) -> int: ...
def stream_iseof(stream: StreamType) -> bool: ...
def evaluate(param: ConstantOrContextLambda2[T], context: Context) -> T: ...

class BytesIOWithOffsets(io.BytesIO):
@staticmethod
def from_reading(
stream: StreamType, length: int, path: PathType
) -> BytesIOWithOffsets: ...
def __init__(
self, contents: bytes, parent_stream: StreamType, offset: int
) -> None: ...
def tell(self) -> int: ...
def seek(self, offset: int, whence: int = ...) -> int: ...

# ===============================================================================
# abstract constructs
# ===============================================================================
Expand Down Expand Up @@ -135,12 +150,19 @@ class Construct(t.Generic[ParsedType, BuildTypes]):
) -> Renamed[ParsedType, BuildTypes]: ...
def __add__(self, other: Construct[t.Any, t.Any]) -> Struct: ...
def __rshift__(self, other: Construct[t.Any, t.Any]) -> Sequence: ...
def __getitem__(
self, count: t.Union[int, t.Callable[[Context], int]]
) -> Array[ParsedType, BuildTypes,]: ...
def _parse(self, stream: StreamType, context: Context, path: PathType) -> ParsedType: ...
def _parsereport(self, stream: StreamType, context: Context, path: PathType) -> ParsedType: ...
def _build(self, obj: BuildTypes, stream: StreamType, context: Context, path: PathType) -> int: ...
def __getitem__(self, count: t.Union[int, t.Callable[[Context], int]]) -> Array[
ParsedType,
BuildTypes,
]: ...
def _parse(
self, stream: StreamType, context: Context, path: PathType
) -> ParsedType: ...
def _parsereport(
self, stream: StreamType, context: Context, path: PathType
) -> ParsedType: ...
def _build(
self, obj: BuildTypes, stream: StreamType, context: Context, path: PathType
) -> int: ...
def _sizeof(self, context: Context, path: PathType) -> int: ...

@t.type_check_only
Expand Down Expand Up @@ -234,15 +256,11 @@ class Bytes(Construct[bytes, t.Union[bytes, bytearray, int]]):

GreedyBytes: Construct[bytes, t.Union[bytes, bytearray]]

def Bitwise(
subcon: Construct[SubconParsedType, SubconBuildTypes]
) -> t.Union[
def Bitwise(subcon: Construct[SubconParsedType, SubconBuildTypes]) -> t.Union[
Transformed[SubconParsedType, SubconBuildTypes],
Restreamed[SubconParsedType, SubconBuildTypes],
]: ...
def Bytewise(
subcon: Construct[SubconParsedType, SubconBuildTypes]
) -> t.Union[
def Bytewise(subcon: Construct[SubconParsedType, SubconBuildTypes]) -> t.Union[
Transformed[SubconParsedType, SubconBuildTypes],
Restreamed[SubconParsedType, SubconBuildTypes],
]: ...
Expand Down Expand Up @@ -880,6 +898,16 @@ class Peek(
subcon: Construct[SubconParsedType, SubconBuildTypes],
) -> None: ...

class OffsettedEnd(
Subconstruct[SubconParsedType, SubconBuildTypes, SubconParsedType, SubconBuildTypes]
):
endoffset: ConstantOrContextLambda[int]
def __init__(
self,
endoffset: ConstantOrContextLambda[int],
subcon: Construct[SubconParsedType, SubconBuildTypes],
) -> None: ...

class Seek(Construct[int, None]):
at: ConstantOrContextLambda[int]
if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -924,9 +952,7 @@ class RawCopy(
def ByteSwapped(
subcon: Construct[SubconParsedType, SubconBuildTypes]
) -> Transformed[SubconParsedType, SubconBuildTypes]: ...
def BitsSwapped(
subcon: Construct[SubconParsedType, SubconBuildTypes]
) -> t.Union[
def BitsSwapped(subcon: Construct[SubconParsedType, SubconBuildTypes]) -> t.Union[
Transformed[SubconParsedType, SubconBuildTypes],
Restreamed[SubconParsedType, SubconBuildTypes],
]: ...
Expand All @@ -946,7 +972,10 @@ class Prefixed(
def PrefixedArray(
countfield: Construct[int, int],
subcon: Construct[SubconParsedType, SubconBuildTypes],
) -> Array[SubconParsedType, SubconBuildTypes,]: ...
) -> Array[
SubconParsedType,
SubconBuildTypes,
]: ...

class FixedSized(
Subconstruct[SubconParsedType, SubconBuildTypes, SubconParsedType, SubconBuildTypes]
Expand Down Expand Up @@ -1095,6 +1124,26 @@ class Rebuffered(
tailcutoff: t.Optional[int] = ...,
) -> None: ...

class EncryptedSym(Tunnel[SubconParsedType, SubconBuildTypes]):
cipher: ConstantOrContextLambda2[Cipher[Mode]]
def __init__(
self,
subcon: Construct[SubconParsedType, SubconBuildTypes],
cipher: ConstantOrContextLambda2[Cipher[Mode]],
) -> None: ...

class EncryptedSymAead(Tunnel[SubconParsedType, SubconBuildTypes]):
cipher: ConstantOrContextLambda2[t.Union[AESGCM, AESCCM, ChaCha20Poly1305]]
nonce: ConstantOrContextLambda2[bytes]
associated_data: ConstantOrContextLambda2[bytes]
def __init__(
self,
subcon: Construct[SubconParsedType, SubconBuildTypes],
cipher: ConstantOrContextLambda2[t.Union[AESGCM, AESCCM, ChaCha20Poly1305]],
nonce: ConstantOrContextLambda2[bytes],
associated_data: ConstantOrContextLambda2[bytes] = ...,
) -> None: ...

# ===============================================================================
# lazy equivalents
# ===============================================================================
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
construct==2.10.68
construct==2.10.70
pytest>=6.2.0
numpy
arrow
Expand All @@ -7,4 +7,5 @@ cloudpickle
lz4
black
isort
mypy
mypy
cryptography
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author="Tim Riddermann",
python_requires=">=3.7",
install_requires=[
"construct==2.10.68",
"construct==2.10.70",
"typing_extensions>=4.6.0"
],
keywords=[
Expand Down
138 changes: 134 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,29 @@ def test_formatfield_bool_issue_901() -> None:
assert d.sizeof() == 1

def test_bytesinteger() -> None:
d = BytesInteger(0)
assert raises(d.parse, b"") == IntegerError
assert raises(d.build, 0) == IntegerError
d = BytesInteger(4, signed=True, swapped=False)
common(d, b"\x01\x02\x03\x04", 0x01020304, 4)
common(d, b"\xff\xff\xff\xff", -1, 4)
d = BytesInteger(4, signed=False, swapped=this.swapped)
common(d, b"\x01\x02\x03\x04", 0x01020304, 4, swapped=False)
common(d, b"\x04\x03\x02\x01", 0x01020304, 4, swapped=True)
assert raises(BytesInteger(-1).parse, b"") == IntegerError
assert raises(BytesInteger(-1).build, 0) == IntegerError
assert raises(BytesInteger(8).build, None) == IntegerError
assert raises(BytesInteger(8, signed=False).build, -1) == IntegerError
assert raises(BytesInteger(8, True).build, -2**64) == IntegerError
assert raises(BytesInteger(8, True).build, 2**64) == IntegerError
assert raises(BytesInteger(8, False).build, -2**64) == IntegerError
assert raises(BytesInteger(8, False).build, 2**64) == IntegerError
assert raises(BytesInteger(this.missing).sizeof) == SizeofError
assert raises(BytesInteger(4, signed=False).build, -1) == IntegerError
common(BytesInteger(0), b"", 0, 0)

def test_bitsinteger() -> None:
d = BitsInteger(0)
assert raises(d.parse, b"") == IntegerError
assert raises(d.build, 0) == IntegerError
d = BitsInteger(8)
common(d, b"\x01\x01\x01\x01\x01\x01\x01\x01", 255, 8)
d = BitsInteger(8, signed=True)
Expand All @@ -171,9 +183,17 @@ def test_bitsinteger() -> None:
d = BitsInteger(16, swapped=this.swapped)
common(d, b"\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00", 0xff00, 16, swapped=False)
common(d, b"\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01", 0xff00, 16, swapped=True)
assert raises(BitsInteger(this.missing).sizeof) == SizeofError
assert raises(BitsInteger(-1).parse, b"") == IntegerError
assert raises(BitsInteger(-1).build, 0) == IntegerError
assert raises(BitsInteger(5, swapped=True).parse, bytes(5)) == IntegerError
assert raises(BitsInteger(5, swapped=True).build, 0) == IntegerError
assert raises(BitsInteger(8).build, None) == IntegerError
assert raises(BitsInteger(8, signed=False).build, -1) == IntegerError
common(BitsInteger(0), b"", 0, 0)
assert raises(BitsInteger(8, True).build, -2**64) == IntegerError
assert raises(BitsInteger(8, True).build, 2**64) == IntegerError
assert raises(BitsInteger(8, False).build, -2**64) == IntegerError
assert raises(BitsInteger(8, False).build, 2**64) == IntegerError
assert raises(BitsInteger(this.missing).sizeof) == SizeofError

def test_varint() -> None:
d = VarInt
Expand Down Expand Up @@ -926,6 +946,17 @@ def test_peek() -> None:
assert d4.build(Container(a=0x01, b=0x0102)) == b""
assert d4.sizeof() == 0

def test_offsettedend() -> None:
d1 = Struct(
"header" / Bytes(2),
"data" / OffsettedEnd(-2, GreedyBytes),
"footer" / Bytes(2),
)
common(d1, b"\x01\x02\x03\x04\x05\x06\x07", Container(header=b'\x01\x02', data=b'\x03\x04\x05', footer=b'\x06\x07'))

d2 = OffsettedEnd(0, Byte)
assert raises(d2.sizeof) == SizeofError

def test_seek() -> None:
d = Seek(5)
assert d.parse(b"") == 5
Expand Down Expand Up @@ -1334,6 +1365,105 @@ def test_compressed_prefixed() -> None:
assert st.parse(st.build(Container(one=zeros,two=zeros))) == Container(one=zeros,two=zeros)
assert raises(d.sizeof) == SizeofError

@pytest.mark.xfail(ONWINDOWS and PYPY, reason="no wheel for 'cryptography' is currently available for pypy on windows")
def test_encryptedsym() -> None:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
key128 = b"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"
key256 = b"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"
iv = b"\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f"
nonce = iv

# AES 128/256 bit - ECB
d = EncryptedSym(GreedyBytes, lambda ctx: Cipher(algorithms.AES(ctx.key), modes.ECB()))
common(d, b"\xf4\x0f\x54\xb7\x6a\x7a\xf1\xdb\x92\x73\x14\xde\x2f\xa0\x3e\x2d", b'Secret Message..', key=key128, iv=iv)
common(d, b"\x82\x6b\x01\x82\x90\x02\xa1\x9e\x35\x0a\xe2\xc3\xee\x1a\x42\xf5", b'Secret Message..', key=key256, iv=iv)

# AES 128/256 bit - CBC
d = EncryptedSym(GreedyBytes, lambda ctx: Cipher(algorithms.AES(ctx.key), modes.CBC(ctx.iv)))
common(d, b"\xba\x79\xc2\x62\x22\x08\x29\xb9\xfb\xd3\x90\xc4\x04\xb7\x55\x87", b'Secret Message..', key=key128, iv=iv)
common(d, b"\x60\xc2\x45\x0d\x7e\x41\xd4\xf8\x85\xd4\x8a\x64\xd1\x45\x49\xe3", b'Secret Message..', key=key256, iv=iv)

# AES 128/256 bit - CTR
d = EncryptedSym(GreedyBytes, lambda ctx: Cipher(algorithms.AES(ctx.key), modes.CTR(ctx.nonce)))
common(d, b"\x80\x78\xb6\x0c\x07\xf5\x0c\x90\xce\xa2\xbf\xcb\x5b\x22\xb9\xb5", b'Secret Message..', key=key128, nonce=nonce)
common(d, b"\x6a\xae\x7b\x86\x1a\xa6\xe0\x6a\x49\x02\x02\x1b\xf2\x3c\xd8\x0d", b'Secret Message..', key=key256, nonce=nonce)

assert raises(EncryptedSym(GreedyBytes, "AES").build, b"") == CipherError # type: ignore
assert raises(EncryptedSym(GreedyBytes, "AES").parse, b"") == CipherError # type: ignore

@pytest.mark.xfail(ONWINDOWS and PYPY, reason="no wheel for 'cryptography' is currently available for pypy on windows")
def test_encryptedsym_cbc_example() -> None:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
d = Struct(
"iv" / Default(Bytes(16), os.urandom(16)),
"enc_data" / EncryptedSym(
Aligned(16,
Struct(
"width" / Int16ul,
"height" / Int16ul
)
),
lambda ctx: Cipher(algorithms.AES(ctx._.key), modes.CBC(ctx.iv))
)
)
key128 = b"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"
byts = d.build({"enc_data": {"width": 5, "height": 4}}, key=key128)
obj = d.parse(byts, key=key128)
assert obj.enc_data == Container(width=5, height=4)

@pytest.mark.xfail(ONWINDOWS and PYPY, reason="no wheel for 'cryptography' is currently available for pypy on windows")
def test_encryptedsymaead() -> None:
from cryptography.hazmat.primitives.ciphers import aead
key128 = b"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"
key256 = b"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"
nonce = b"\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f"

# AES 128/256 bit - GCM
d = Struct(
"associated_data" / Bytes(21),
"data" / EncryptedSymAead(
GreedyBytes,
lambda ctx: aead.AESGCM(ctx._.key),
this._.nonce,
this.associated_data
)
)
common(
d,
b"This is authenticated\xb6\xd3\x64\x0c\x7a\x31\xaa\x16\xa3\x58\xec\x17\x39\x99\x2e\xf8\x4e\x41\x17\x76\x3f\xd1\x06\x47\x04\x9f\x42\x1c\xf4\xa9\xfd\x99\x9c\xe9",
Container(associated_data=b"This is authenticated", data=b"The secret message"),
key=key128,
nonce=nonce
)
common(
d,
b"This is authenticated\xde\xb4\x41\x79\xc8\x7f\xea\x8d\x0e\x41\xf6\x44\x2f\x93\x21\xe6\x37\xd1\xd3\x29\xa4\x97\xc3\xb5\xf4\x81\x72\xa1\x7f\x3b\x9b\x53\x24\xe4",
Container(associated_data=b"This is authenticated", data=b"The secret message"),
key=key256,
nonce=nonce
)
assert raises(EncryptedSymAead(GreedyBytes, "AESGCM", bytes(16)).build, b"") == CipherError # type: ignore
assert raises(EncryptedSymAead(GreedyBytes, "AESGCM", bytes(16)).parse, b"") == CipherError # type: ignore

@pytest.mark.xfail(ONWINDOWS and PYPY, reason="no wheel for 'cryptography' is currently available for pypy on windows")
def test_encryptedsymaead_gcm_example() -> None:
from cryptography.hazmat.primitives.ciphers import aead
d = Struct(
"nonce" / Default(Bytes(16), os.urandom(16)),
"associated_data" / Bytes(21),
"enc_data" / EncryptedSymAead(
GreedyBytes,
lambda ctx: aead.AESGCM(ctx._.key),
this.nonce,
this.associated_data
)
)
key128 = b"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"
byts = d.build({"associated_data": b"This is authenticated", "enc_data": b"The secret message"}, key=key128)
obj = d.parse(byts, key=key128)
assert obj.enc_data == b"The secret message"
assert obj.associated_data == b"This is authenticated"

def test_rebuffered() -> None:
data = b"0" * 1000
assert Rebuffered(Array(1000,Byte)).parse_stream(io.BytesIO(data)) == [48]*1000
Expand Down