From cfd39873277b07e17343350b00de32025e37c8c4 Mon Sep 17 00:00:00 2001 From: IAL32 Date: Tue, 1 Mar 2022 14:57:10 +0000 Subject: [PATCH 1/3] More checks on numeric input --- programs/zstdcli.c | 74 ++++++++++++++++++++++++++++++++++++++-------- tests/playTests.sh | 7 +++++ 2 files changed, 69 insertions(+), 12 deletions(-) diff --git a/programs/zstdcli.c b/programs/zstdcli.c index 29da261dfbd..9eddcb5d889 100644 --- a/programs/zstdcli.c +++ b/programs/zstdcli.c @@ -312,23 +312,56 @@ static void errorOut(const char* msg) DISPLAY("%s \n", msg); exit(1); } +/*! nextCharValidCheck() : + * @return 0 if next char may be valid : + * next char is either the end of the string or one of the accepted argument characters. + * This method does not assure absolute validity of the argument, but covers + * a large portion of wrong cases. + * @return 1 if next char is not an accepted character. +*/ +static int nextCharValidCheck(const char** stringPtr) { + char alphabet[32] = "VHhzdcDfvqkCtoMlTspP,"; + size_t alphaNb = 0; + int invalid = 1; +#ifdef UTIL_HAS_CREATEFILELIST + strcat(alphabet, "r"); +#endif +#ifndef ZSTD_NOBENCH + strcat(alphabet, "beiBS"); +#endif + + for ( ; invalid && alphaNb < strlen(alphabet); alphaNb++ ) { + if (**stringPtr == alphabet[alphaNb]) { + invalid = 0; + } + } + if (**stringPtr == 0) { + invalid = 0; + } + + return invalid; +} + /*! readU32FromCharChecked() : * @return 0 if success, and store the result in *value. * allows and interprets K, KB, KiB, M, MB and MiB suffix. * Will also modify `*stringPtr`, advancing it to position where it stopped reading. 
- * @return 1 if an overflow error occurs */ + * @return 1 if an overflow error occurs, 2 if the number is in a wrong format */ static int readU32FromCharChecked(const char** stringPtr, unsigned* value) { unsigned result = 0; + int valid = 0; while ((**stringPtr >='0') && (**stringPtr <='9')) { unsigned const max = ((unsigned)(-1)) / 10; unsigned last = result; + valid = 1; if (result > max) return 1; /* overflow error */ result *= 10; result += (unsigned)(**stringPtr - '0'); if (result < last) return 1; /* overflow error */ (*stringPtr)++ ; } + if (!valid && **stringPtr == 0) return 2; /* wrong format error */ if ((**stringPtr=='K') || (**stringPtr=='M')) { unsigned const maxK = ((unsigned)(-1)) >> 10; if (result > maxK) return 1; /* overflow error */ @@ -341,6 +374,7 @@ static int readU32FromCharChecked(const char** stringPtr, unsigned* value) if (**stringPtr=='i') (*stringPtr)++; if (**stringPtr=='B') (*stringPtr)++; } + if (!valid || nextCharValidCheck(stringPtr)) return 2; /* wrong format error */ *value = result; return 0; } @@ -351,9 +385,12 @@ static int readU32FromCharChecked(const char** stringPtr, unsigned* value) * Will also modify `*stringPtr`, advancing it to position where it stopped reading. 
* Note : function will exit() program if digit sequence overflows */ static unsigned readU32FromChar(const char** stringPtr) { - static const char errorMsg[] = "error: numeric value overflows 32-bit unsigned int"; + static const char overflowErrorMsg[] = "error: numeric value overflows 32-bit unsigned int"; + static const char formatErrorMsg[] = "error: wrong number format"; unsigned result; - if (readU32FromCharChecked(stringPtr, &result)) { errorOut(errorMsg); } + int checkResult = readU32FromCharChecked(stringPtr, &result); + if (checkResult == 1) { errorOut(overflowErrorMsg); } + else if (checkResult == 2) { errorOut(formatErrorMsg); } return result; } @@ -363,14 +400,18 @@ static unsigned readU32FromChar(const char** stringPtr) { * Will also modify `*stringPtr`, advancing it to position where it stopped reading. * Note : function will exit() program if digit sequence overflows */ static int readIntFromChar(const char** stringPtr) { - static const char errorMsg[] = "error: numeric value overflows 32-bit int"; + static const char overflowErrorMsg[] = "error: numeric value overflows 32-bit int"; + static const char formatErrorMsg[] = "error: wrong number format"; int sign = 1; unsigned result; + int checkResult; if (**stringPtr=='-') { (*stringPtr)++; sign = -1; } - if (readU32FromCharChecked(stringPtr, &result)) { errorOut(errorMsg); } + checkResult = readU32FromCharChecked(stringPtr, &result); + if (checkResult == 1) { errorOut(overflowErrorMsg); } + else if (checkResult == 2) { errorOut(formatErrorMsg); } return (int) result * sign; } @@ -378,19 +419,22 @@ static int readIntFromChar(const char** stringPtr) { * @return 0 if success, and store the result in *value. * allows and interprets K, KB, KiB, M, MB and MiB suffix. * Will also modify `*stringPtr`, advancing it to position where it stopped reading. 
- * @return 1 if an overflow error occurs */ + * @return 1 if an overflow error occurs, 2 if the number is in a wrong format */ static int readSizeTFromCharChecked(const char** stringPtr, size_t* value) { size_t result = 0; + int valid = 0; while ((**stringPtr >='0') && (**stringPtr <='9')) { size_t const max = ((size_t)(-1)) / 10; size_t last = result; + valid = 1; if (result > max) return 1; /* overflow error */ result *= 10; result += (size_t)(**stringPtr - '0'); if (result < last) return 1; /* overflow error */ - (*stringPtr)++ ; + (*stringPtr)++; } + if (!valid && **stringPtr == 0) return 2; /* wrong format error */ if ((**stringPtr=='K') || (**stringPtr=='M')) { size_t const maxK = ((size_t)(-1)) >> 10; if (result > maxK) return 1; /* overflow error */ @@ -403,6 +447,7 @@ static int readSizeTFromCharChecked(const char** stringPtr, size_t* value) if (**stringPtr=='i') (*stringPtr)++; if (**stringPtr=='B') (*stringPtr)++; } + if (!valid || nextCharValidCheck(stringPtr)) return 2; /* wrong format error */ *value = result; return 0; } @@ -413,9 +458,14 @@ static int readSizeTFromCharChecked(const char** stringPtr, size_t* value) * Will also modify `*stringPtr`, advancing it to position where it stopped reading. 
* Note : function will exit() program if digit sequence overflows */ static size_t readSizeTFromChar(const char** stringPtr) { - static const char errorMsg[] = "error: numeric value overflows size_t"; + static const char overflowErrorMsg[] = "error: numeric value overflows size_t"; + static const char formatErrorMsg[] = "error: wrong number format"; size_t result; - if (readSizeTFromCharChecked(stringPtr, &result)) { errorOut(errorMsg); } + int checkResult; + + checkResult = readSizeTFromCharChecked(stringPtr, &result); + if (checkResult == 1) { errorOut(overflowErrorMsg); } + else if (checkResult == 2) { errorOut(formatErrorMsg); } return result; } @@ -988,9 +1038,9 @@ int main(int argCount, const char* argv[]) } #endif if (longCommandWArg(&argument, "--threads")) { NEXT_UINT32(nbWorkers); continue; } - if (longCommandWArg(&argument, "--memlimit")) { NEXT_UINT32(memLimit); continue; } - if (longCommandWArg(&argument, "--memory")) { NEXT_UINT32(memLimit); continue; } - if (longCommandWArg(&argument, "--memlimit-decompress")) { NEXT_UINT32(memLimit); continue; } + if (longCommandWArg(&argument, "--memlimit=")) { memLimit = readSizeTFromChar(&argument); continue; } + if (longCommandWArg(&argument, "--memory=")) { memLimit = readSizeTFromChar(&argument); continue; } + if (longCommandWArg(&argument, "--memlimit-decompress=")) { memLimit = readSizeTFromChar(&argument); continue; } if (longCommandWArg(&argument, "--block-size=")) { blockSize = readSizeTFromChar(&argument); continue; } if (longCommandWArg(&argument, "--maxdict")) { NEXT_UINT32(maxDictSize); continue; } if (longCommandWArg(&argument, "--dictID")) { NEXT_UINT32(dictID); continue; } diff --git a/tests/playTests.sh b/tests/playTests.sh index 71e8dc05818..40a2975a860 100755 --- a/tests/playTests.sh +++ b/tests/playTests.sh @@ -275,6 +275,13 @@ zstd -d -f tmplimit.zst --memlimit=2K -c > $INTOVOID && die "decompression needs zstd -d -f tmplimit.zst --memory=2K -c > $INTOVOID && die "decompression needs more 
memory than allowed" # long command zstd -d -f tmplimit.zst --memlimit-decompress=2K -c > $INTOVOID && die "decompression needs more memory than allowed" # long command rm -f tmplimit tmplimit.zst +println foo > tmpmemory.zst +println "test : zstd parameter parsing (must fail)" +zstd -d -f tmpmemory.zst --memory= -c > $INTOVOID && die "memory parameter is in an invalid format" +zstd -d -f tmpmemory.zst --memory=hello -c > $INTOVOID && die "memory parameter is in an invalid format" +zstd -d -f tmpmemory.zst --memlimit=512LB -c > $INTOVOID && die "memory parameter is in an invalid format" +zstd -d -f tmpmemory.zst --memlimit-decompress=512LiB -c > $INTOVOID && die "memory parameter is in an invalid format" +rm tmpmemory.zst println "test : overwrite protection" zstd -q tmp && die "overwrite check failed!" println "test : force overwrite" From 8173d14e17aff2d74d00bf1cb32a5fedc0f95b3b Mon Sep 17 00:00:00 2001 From: IAL32 Date: Tue, 1 Mar 2022 15:06:40 +0000 Subject: [PATCH 2/3] `memLimit` is now of type `size_t` --- programs/zstdcli.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/programs/zstdcli.c b/programs/zstdcli.c index 9eddcb5d889..6b65cfa5708 100644 --- a/programs/zstdcli.c +++ b/programs/zstdcli.c @@ -871,7 +871,7 @@ int main(int argCount, const char* argv[]) int cLevel = init_cLevel(); int cLevelLast = MINCLEVEL - 1; /* lower than minimum */ unsigned recursive = 0; - unsigned memLimit = 0; + size_t memLimit = 0; FileNamesTable* filenames = UTIL_allocateFileNamesTable((size_t)argCount); /* argCount >= 1 */ FileNamesTable* file_of_names = UTIL_allocateFileNamesTable((size_t)argCount); /* argCount >= 1 */ const char* programName = argv[0]; From cc1833b2b8a8b29e13dbe06481726e11b8aebab1 Mon Sep 17 00:00:00 2001 From: IAL32 Date: Tue, 1 Mar 2022 15:13:07 +0000 Subject: [PATCH 3/3] Revert `memLimit` from `size_t` to `unsigned` Changed read method instead --- programs/zstdcli.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/programs/zstdcli.c b/programs/zstdcli.c index 6b65cfa5708..d589616c406 100644 --- a/programs/zstdcli.c +++ b/programs/zstdcli.c @@ -871,7 +871,7 @@ int main(int argCount, const char* argv[]) int cLevel = init_cLevel(); int cLevelLast = MINCLEVEL - 1; /* lower than minimum */ unsigned recursive = 0; - size_t memLimit = 0; + unsigned memLimit = 0; FileNamesTable* filenames = UTIL_allocateFileNamesTable((size_t)argCount); /* argCount >= 1 */ FileNamesTable* file_of_names = UTIL_allocateFileNamesTable((size_t)argCount); /* argCount >= 1 */ const char* programName = argv[0]; @@ -1038,9 +1038,9 @@ int main(int argCount, const char* argv[]) } #endif if (longCommandWArg(&argument, "--threads")) { NEXT_UINT32(nbWorkers); continue; } - if (longCommandWArg(&argument, "--memlimit=")) { memLimit = readSizeTFromChar(&argument); continue; } - if (longCommandWArg(&argument, "--memory=")) { memLimit = readSizeTFromChar(&argument); continue; } - if (longCommandWArg(&argument, "--memlimit-decompress=")) { memLimit = readSizeTFromChar(&argument); continue; } + if (longCommandWArg(&argument, "--memlimit=")) { memLimit = readU32FromChar(&argument); continue; } + if (longCommandWArg(&argument, "--memory=")) { memLimit = readU32FromChar(&argument); continue; } + if (longCommandWArg(&argument, "--memlimit-decompress=")) { memLimit = readU32FromChar(&argument); continue; } if (longCommandWArg(&argument, "--block-size=")) { blockSize = readSizeTFromChar(&argument); continue; } if (longCommandWArg(&argument, "--maxdict")) { NEXT_UINT32(maxDictSize); continue; } if (longCommandWArg(&argument, "--dictID")) { NEXT_UINT32(dictID); continue; }