diff --git a/programs/zstdcli.c b/programs/zstdcli.c index 29da261dfbd..d589616c406 100644 --- a/programs/zstdcli.c +++ b/programs/zstdcli.c @@ -312,23 +312,56 @@ static void errorOut(const char* msg) DISPLAY("%s \n", msg); exit(1); } +/*! nextCharValidCheck() : + * @return 0 if next char maybe valid. + * next char is necessarily one of the accepted arguments. + * This method does not assure absolute validity of the argument, but covers + * a large portion of wrong cases. + * @return 1 if next char is not an accepted character. +*/ +static int nextCharValidCheck(const char** stringPtr) { + char alphabet[32] = "VHhzdcDfvqkCtoMlTspP,"; + size_t alphaNb = 0; + int invalid = 1; +#ifdef UTIL_HAS_CREATEFILELIST + strcat(alphabet, "r"); +#endif +#ifndef ZSTD_NOBENCH + strcat(alphabet, "beiBS"); +#endif + + for ( ; invalid && alphaNb < strlen(alphabet); alphaNb++ ) { + if (**stringPtr == alphabet[alphaNb]) { + invalid = 0; + } + } + if (**stringPtr == 0) { + invalid = 0; + } + + return invalid; +} + /*! readU32FromCharChecked() : * @return 0 if success, and store the result in *value. * allows and interprets K, KB, KiB, M, MB and MiB suffix. * Will also modify `*stringPtr`, advancing it to position where it stopped reading. - * @return 1 if an overflow error occurs */ + * @return 1 if an overflow error occurs, 2 if the number is in a wrong format */ static int readU32FromCharChecked(const char** stringPtr, unsigned* value) { unsigned result = 0; + int valid = 0; while ((**stringPtr >='0') && (**stringPtr <='9')) { unsigned const max = ((unsigned)(-1)) / 10; unsigned last = result; + valid = 1; if (result > max) return 1; /* overflow error */ result *= 10; result += (unsigned)(**stringPtr - '0'); if (result < last) return 1; /* overflow error */ (*stringPtr)++ ; } + if (!valid && **stringPtr == 0) return 2; /* wrong format error */ if ((**stringPtr=='K') || (**stringPtr=='M')) { unsigned const maxK = ((unsigned)(-1)) >> 10; if (result > maxK) return 1; /* overflow error */ @@ -341,6 +374,7 @@ static int readU32FromCharChecked(const char** stringPtr, unsigned* value) if (**stringPtr=='i') (*stringPtr)++; if (**stringPtr=='B') (*stringPtr)++; } + if (!valid || nextCharValidCheck(stringPtr)) return 2; /* wrong format error */ *value = result; return 0; } @@ -351,9 +385,12 @@ static int readU32FromCharChecked(const char** stringPtr, unsigned* value) * Will also modify `*stringPtr`, advancing it to position where it stopped reading. * Note : function will exit() program if digit sequence overflows */ static unsigned readU32FromChar(const char** stringPtr) { - static const char errorMsg[] = "error: numeric value overflows 32-bit unsigned int"; + static const char overflowErrorMsg[] = "error: numeric value overflows 32-bit unsigned int"; + static const char formatErrorMsg[] = "error: wrong number format"; unsigned result; - if (readU32FromCharChecked(stringPtr, &result)) { errorOut(errorMsg); } + int checkResult = readU32FromCharChecked(stringPtr, &result); + if (checkResult == 1) { errorOut(overflowErrorMsg); } + else if (checkResult == 2) { errorOut(formatErrorMsg); } return result; } @@ -363,14 +400,18 @@ static unsigned readU32FromChar(const char** stringPtr) { * Will also modify `*stringPtr`, advancing it to position where it stopped reading. * Note : function will exit() program if digit sequence overflows */ static int readIntFromChar(const char** stringPtr) { - static const char errorMsg[] = "error: numeric value overflows 32-bit int"; + static const char overflowErrorMsg[] = "error: numeric value overflows 32-bit int"; + static const char formatErrorMsg[] = "error: wrong number format"; int sign = 1; unsigned result; + int checkResult; if (**stringPtr=='-') { (*stringPtr)++; sign = -1; } - if (readU32FromCharChecked(stringPtr, &result)) { errorOut(errorMsg); } + checkResult = readU32FromCharChecked(stringPtr, &result); + if (checkResult == 1) { errorOut(overflowErrorMsg); } + else if (checkResult == 2) { errorOut(formatErrorMsg); } return (int) result * sign; } @@ -378,19 +419,22 @@ static int readIntFromChar(const char** stringPtr) { * @return 0 if success, and store the result in *value. * allows and interprets K, KB, KiB, M, MB and MiB suffix. * Will also modify `*stringPtr`, advancing it to position where it stopped reading. - * @return 1 if an overflow error occurs */ + * @return 1 if an overflow error occurs, 2 if the number is in a wrong format */ static int readSizeTFromCharChecked(const char** stringPtr, size_t* value) { size_t result = 0; + int valid = 0; while ((**stringPtr >='0') && (**stringPtr <='9')) { size_t const max = ((size_t)(-1)) / 10; size_t last = result; + valid = 1; if (result > max) return 1; /* overflow error */ result *= 10; result += (size_t)(**stringPtr - '0'); if (result < last) return 1; /* overflow error */ - (*stringPtr)++ ; + (*stringPtr)++; } + if (!valid && **stringPtr == 0) return 2; /* wrong format error */ if ((**stringPtr=='K') || (**stringPtr=='M')) { size_t const maxK = ((size_t)(-1)) >> 10; if (result > maxK) return 1; /* overflow error */ @@ -403,6 +447,7 @@ static int readSizeTFromCharChecked(const char** stringPtr, size_t* value) if (**stringPtr=='i') (*stringPtr)++; if (**stringPtr=='B') (*stringPtr)++; } + if (!valid || nextCharValidCheck(stringPtr)) return 2; /* wrong format error */ *value = result; return 0; } @@ -413,9 +458,14 @@ static int readSizeTFromCharChecked(const char** stringPtr, size_t* value) * Will also modify `*stringPtr`, advancing it to position where it stopped reading. * Note : function will exit() program if digit sequence overflows */ static size_t readSizeTFromChar(const char** stringPtr) { - static const char errorMsg[] = "error: numeric value overflows size_t"; + static const char overflowErrorMsg[] = "error: numeric value overflows size_t"; + static const char formatErrorMsg[] = "error: wrong number format"; size_t result; - if (readSizeTFromCharChecked(stringPtr, &result)) { errorOut(errorMsg); } + int checkResult; + + checkResult = readSizeTFromCharChecked(stringPtr, &result); + if (checkResult == 1) { errorOut(overflowErrorMsg); } + else if (checkResult == 2) { errorOut(formatErrorMsg); } return result; } @@ -988,9 +1038,9 @@ int main(int argCount, const char* argv[]) } #endif if (longCommandWArg(&argument, "--threads")) { NEXT_UINT32(nbWorkers); continue; } - if (longCommandWArg(&argument, "--memlimit")) { NEXT_UINT32(memLimit); continue; } - if (longCommandWArg(&argument, "--memory")) { NEXT_UINT32(memLimit); continue; } - if (longCommandWArg(&argument, "--memlimit-decompress")) { NEXT_UINT32(memLimit); continue; } + if (longCommandWArg(&argument, "--memlimit=")) { memLimit = readU32FromChar(&argument); continue; } + if (longCommandWArg(&argument, "--memory=")) { memLimit = readU32FromChar(&argument); continue; } + if (longCommandWArg(&argument, "--memlimit-decompress=")) { memLimit = readU32FromChar(&argument); continue; } if (longCommandWArg(&argument, "--block-size=")) { blockSize = readSizeTFromChar(&argument); continue; } if (longCommandWArg(&argument, "--maxdict")) { NEXT_UINT32(maxDictSize); continue; } if (longCommandWArg(&argument, "--dictID")) { NEXT_UINT32(dictID); continue; } diff --git a/tests/playTests.sh b/tests/playTests.sh index 71e8dc05818..40a2975a860 100755 --- a/tests/playTests.sh +++ b/tests/playTests.sh @@ -275,6 +275,13 @@ zstd -d -f tmplimit.zst --memlimit=2K -c > $INTOVOID && die "decompression needs zstd -d -f tmplimit.zst --memory=2K -c > $INTOVOID && die "decompression needs more memory than allowed" # long command zstd -d -f tmplimit.zst --memlimit-decompress=2K -c > $INTOVOID && die "decompression needs more memory than allowed" # long command rm -f tmplimit tmplimit.zst +println foo > tmpmemory.zst +println "test : zstd parameter parsing (must fail)" +zstd -d -f tmpmemory.zst --memory= -c > $INTOVOID && die "memory parameter is in an invalid format" +zstd -d -f tmpmemory.zst --memory=hello -c > $INTOVOID && die "memory parameter is in an invalid format" +zstd -d -f tmpmemory.zst --memlimit=512LB -c > $INTOVOID && die "memory parameter is in an invalid format" +zstd -d -f tmpmemory.zst --memlimit-decompress=512LiB -c > $INTOVOID && die "memory parameter is in an invalid format" +rm tmpmemory.zst println "test : overwrite protection" zstd -q tmp && die "overwrite check failed!" println "test : force overwrite"