diff --git a/.gitignore b/.gitignore index 28bf1c4..cbb6d70 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ +*.o serve serve_* serve.gcno logs/log_*.txt *.info -report \ No newline at end of file +logs +report +blacklist.txt diff --git a/Makefile b/Makefile index b1cdea1..0851df0 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +.PHONY: all report cppcheck gcc-analyzer clang-tidy + all: serve_debug serve_cov serve serve_debug: serve.c @@ -13,3 +15,13 @@ report: lcov --capture --directory . --output-file coverage.info @ mkdir -p report genhtml coverage.info --output-directory report + +cppcheck: + cppcheck -j1 --enable=portability serve.c + cppcheck -j1 --enable=style serve.c + +gcc-analyzer: + gcc -c -fanalyzer serve.c + +clang-tidy: + clang-tidy serve.c \ No newline at end of file diff --git a/README.md b/README.md index 63b583c..17fe465 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,26 @@ This is a minimal web server designed to serve my blog. I'm writing it to be robust enough to face the public internet. You can see it in action at http://playin.coz.is/index.html. You probably can't get it to crash, but feel free to try! And if you manage to do it, send me an email to show off! I'll leave the coolest attempts in `attempts.txt`. # Specs -- Linux only -- HTTP/1.1 + pipelining + keep-alive -- Request and connection deadlines +- Only runs on Linux +- HTTP/1.1 support with pipelining and keep-alive +- Uses request and connection timeouts +- IP blacklist +- Access log, log file rotation, hard disk usage limits - No `Transfer-Encoding: Chunked` (when receiving a chunked request the server responds with `411 Length Required`, prompting the client to try again with the `Content-Length` header) - Static file serving utilities - Single core (This will probably change when I get a better VPS) -# How I test it -I'm still testing manually. I usually stress test the server locally using `wrk` and see if it breaks. I also test it under `valgrind` and sanitizers. +# Testing +I routinely run the server under valgrind or sanitizers (address, undefined) and target it using `wrk`. I'm also adding automatized tests to `tests/test.py` to check compliance with the HTTP/1.1 spec. + +# Blocking IPs +To block any number of IP addresses you need to create a `blacklist.txt` file and insert the IPs. You can also add comments: +``` +# I'm a comment +10.0.0.1 +127.0.0.1 # I'm a comment too +``` +Blocked addresses will be rejected after being accepted. This is just a best effort solution as you should block connections using iptables or nftables. # Known Issues - Server replies to HTTP/1.0 clients as HTTP/1.1 -- Since poll is edge triggered, when the server is full and can't accept all new connections the remaining ones are left waiting until some other event wakes up poll() diff --git a/serve.c b/serve.c index 18a14ff..1fbe605 100644 --- a/serve.c +++ b/serve.c @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -19,19 +20,30 @@ #ifdef RELEASE #define PORT 80 +#define MAX_CONNECTIONS 1024 +#define LOG_BUFFER_SIZE (1<<20) +#define LOG_FILE_LIMIT (1<<24) #define LOG_DIRECTORY_SIZE_LIMIT_MB (25 * 1024) #else #define PORT 8080 -#define LOG_DIRECTORY_SIZE_LIMIT_MB 10 +#define MAX_CONNECTIONS 32 +#define LOG_BUFFER_SIZE (1<<10) +#define LOG_FILE_LIMIT (1<<20) +#define LOG_DIRECTORY_SIZE_LIMIT_MB 100 #endif +#define LOG_DIRECTORY "logs" + +#define BLACKLIST 1 +#define BLACKLIST_FILE "blacklist.txt" +#define BLACKLIST_LIMIT 1024 + +#define ACCESS_LOG 1 #define SHOW_IO 0 -#define SHOW_REQUESTS 1 +#define SHOW_REQUESTS 0 #define REQUEST_TIMEOUT_SEC 5 #define CLOSING_TIMEOUT_SEC 2 #define CONNECTION_TIMEOUT_SEC 60 -#define LOG_BUFFER_SIZE (1<<20) -#define LOG_BUFFER_LIMIT (1<<24) #define LOG_FLUSH_TIMEOUT_SEC 3 #define INPUT_BUFFER_LIMIT_MB 1 @@ -41,7 +53,7 @@ #define DEBUG(...) {} #endif -static_assert(LOG_BUFFER_SIZE < LOG_BUFFER_LIMIT, ""); +static_assert(LOG_BUFFER_SIZE < LOG_FILE_LIMIT, ""); typedef struct { char *data; @@ -65,12 +77,13 @@ void log_format(const char *fmt, ...); void log_flush(void); bool log_empty(void); -uint64_t get_current_time_ms(void) -{ - struct timespec ts; - int ret = clock_gettime(CLOCK_MONOTONIC, &ts); - if (ret) log_fatal(LIT("Couldn't read time\n")); +#if BLACKLIST +bool ip_allowed(uint32_t ip); +bool load_blacklist(void); +#endif +uint64_t timespec_to_ms(struct timespec ts) +{ if ((uint64_t) ts.tv_sec > UINT64_MAX / 1000) log_fatal(LIT("Time overflow\n")); uint64_t ms = ts.tv_sec * 1000; @@ -82,12 +95,25 @@ uint64_t get_current_time_ms(void) return ms; } +uint64_t get_monotonic_time_ms(void) +{ + struct timespec ts; + int ret = clock_gettime(CLOCK_MONOTONIC, &ts); + if (ret) log_fatal(LIT("Couldn't read monotonic time\n")); + return timespec_to_ms(ts); +} + +uint64_t get_real_time_ms(void) +{ + struct timespec ts; + int ret = clock_gettime(CLOCK_REALTIME, &ts); + if (ret) log_fatal(LIT("Couldn't read real time\n")); + return timespec_to_ms(ts); +} + void *mymalloc(size_t num) { - int x = 1; // rand(); - if (x & 1) - return malloc(num); - return NULL; + return malloc(num); } /////////////////////////////////////////////////////////////////////////////////////// @@ -162,6 +188,7 @@ bool endswith(string suffix, string name) return suffix.size <= name.size && !memcmp(tail, suffix.data, suffix.size); } +// TODO: Make sure every string in request is reasonaly long int parse_request_head(string str, Request *request) { char *src = str.data; @@ -327,7 +354,7 @@ int parse_request_head(string str, Request *request) request->nheaders++; } } - cur += 2; // \r\n + // cur here points to the \r in \r\n return P_OK; } @@ -717,6 +744,7 @@ void print_bytes(string prefix, string str) typedef struct { ByteQueue input; ByteQueue output; + uint32_t ipaddr; int served_count; bool closing; bool keep_alive; @@ -804,6 +832,7 @@ void add_header_f(ResponseBuilder *b, const char *fmt, ...) bool should_keep_alive(Connection *conn); uint64_t now; +uint64_t real_now; void append_special_headers(ResponseBuilder *b) { @@ -928,8 +957,6 @@ void response_builder_complete(ResponseBuilder *b) void respond(Request request, ResponseBuilder *b); -#define MAX_CONNECTIONS 1024 - struct pollfd pollarray[MAX_CONNECTIONS+1]; Connection conns[MAX_CONNECTIONS]; int num_conns = 0; @@ -937,28 +964,20 @@ int num_conns = 0; bool should_keep_alive(Connection *conn) { // Don't keep alive if the peer doesn't want to - if (conn->keep_alive == false) { - DEBUG("Not keeping alive because peer wants to close\n"); + if (conn->keep_alive == false) return false; - } // Don't keep alive if the request is too old - if (now - conn->creation_time > CONNECTION_TIMEOUT_SEC * 1000) { - DEBUG("Not keeping alive because the connection is too old\n"); + if (now - conn->creation_time > CONNECTION_TIMEOUT_SEC * 1000) return false; - } // Don't keep alive if we served a lot of requests to this connection - if (conn->served_count > 100) { - DEBUG("Not keeping alive because too many requests were served using this connection\n"); + if (conn->served_count > 100) return false; - } // Don't keep alive if the server is more than 70% full - if (num_conns > 0.7 * MAX_CONNECTIONS) { - DEBUG("Not keeping alive because the server is more than 70%% full\n"); + if (num_conns > 0.7 * MAX_CONNECTIONS) return false; - } return true; } @@ -1000,8 +1019,42 @@ bool respond_to_available_requests(struct pollfd *polldata, Connection *conn) #endif // Found! We got the request head + Request request; int res = parse_request_head((string) {src.data, head_length}, &request); + +#if ACCESS_LOG + { + // Log access + time_t real_now_in_secs = real_now / 1000; + struct tm timeinfo; + localtime_r(&real_now_in_secs, &timeinfo); + char timebuf[128]; + size_t timelen = strftime(timebuf, sizeof(timebuf), "%Y/%m/%d %H:%M:%S", &timeinfo); + if (timelen == 0) + log_fatal(LIT("Couldn't format time for access log")); + timebuf[timelen] = '\0'; + + char ipbuf[INET_ADDRSTRLEN]; + const char *ipstr = inet_ntop(AF_INET, &conn->ipaddr, ipbuf, sizeof(ipbuf)); + if (ipstr == NULL) + log_fatal(LIT("Couldn't format IP address for access log")); + + if (res == P_OK) { + string user_agent; + if (!find_header(&request, LIT("User-Agent"), &user_agent)) + user_agent = LIT("No User-Agent"); + else + user_agent = trim(user_agent); + log_format("%s - %s - %.*s - %.*s\n", timebuf, ipstr, + (int) request.path.size, request.path.data, + (int) user_agent.size, user_agent.data); + } else { + log_format("%s - %s - Bad request\n", timebuf, ipstr); + } + } +#endif + if (res != P_OK) { // Invalid HTTP request byte_queue_write(&conn->output, LIT( @@ -1079,17 +1132,9 @@ bool respond_to_available_requests(struct pollfd *polldata, Connection *conn) conn->keep_alive = false; string keep_alive_header; if (find_header(&request, LIT("Connection"), &keep_alive_header)) { - DEBUG("Found Connection header\n"); - if (string_match_case_insensitive(trim(keep_alive_header), LIT("Keep-Alive"))) { + if (string_match_case_insensitive(trim(keep_alive_header), LIT("Keep-Alive"))) conn->keep_alive = true; - DEBUG("Matched Keep-Alive (%.*s)\n", (int) trim(keep_alive_header).size, trim(keep_alive_header).data); - } else { - DEBUG("Didn't match Keep-Alive (%.*s)\n", (int) trim(keep_alive_header).size, trim(keep_alive_header).data); - } - } else { - DEBUG("No Connection header\n"); } - // Respond ResponseBuilder builder; response_builder_init(&builder, conn); @@ -1113,11 +1158,11 @@ bool respond_to_available_requests(struct pollfd *polldata, Connection *conn) pipeline_count++; if (pipeline_count == 10) { // TODO: We should send a response to the client instead of dropping it + log_data(LIT("Pipeline limit reached")); remove = true; break; } } - } return remove; @@ -1142,6 +1187,7 @@ bool read_from_socket(int fd, ByteQueue *queue) continue; if (errno == EAGAIN || errno == EWOULDBLOCK) break; + log_perror(LIT("recv")); remove = true; break; } @@ -1226,6 +1272,13 @@ int main(int argc, char **argv) return -1; } +#if BLACKLIST + if (!load_blacklist()) { + log_data(LIT("Couldn't load blacklist\n")); + return -1; + } +#endif + int one = 1; setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, (char*) &one, sizeof(one)); @@ -1247,45 +1300,61 @@ int main(int argc, char **argv) pollarray[0].events = POLLIN; uint64_t last_log_time = 0; - - int timeout = -1; + bool pending_accept = false; + int timeout = log_empty() ? -1 : LOG_FLUSH_TIMEOUT_SEC * 1000; while (!stop) { - DEBUG("timeout=%d\n", timeout); - int ret = poll(pollarray, MAX_CONNECTIONS, timeout); if (ret < 0) { if (errno == EINTR) - break; + break; // TODO: Should this be continue? log_perror(LIT("poll")); return -1; } - now = get_current_time_ms(); + now = get_monotonic_time_ms(); + real_now = get_real_time_ms(); + + if (pollarray[0].revents || pending_accept) { + + pending_accept = false; - if (pollarray[0].revents) { for (;;) { // Look for a connection structure - int free_index = 0; - while (free_index < MAX_CONNECTIONS && pollarray[free_index].fd != -1) + int free_index = 1; + while (free_index-1 < MAX_CONNECTIONS && pollarray[free_index].fd != -1) free_index++; - if (free_index == MAX_CONNECTIONS) { + if (free_index-1 == MAX_CONNECTIONS) { pollarray[0].events &= ~POLLIN; // Stop listening for incoming connections + pending_accept = true; break; } - int accepted_fd = accept(listen_fd, NULL, NULL); + struct sockaddr_in accepted_addr; + socklen_t accepted_addrlen = sizeof(accepted_addr); + int accepted_fd = accept(listen_fd, (struct sockaddr*) &accepted_addr, &accepted_addrlen); if (accepted_fd < 0) { if (errno == EINTR) continue; if (errno == EAGAIN || errno == EWOULDBLOCK) break; log_perror(LIT("accept")); + close(accepted_fd); break; } + +#if BLACKLIST + if (!ip_allowed((uint32_t) accepted_addr.sin_addr.s_addr)) { + log_data(LIT("Connection Rejected\n")); + close(accepted_fd); + continue; + } +#endif + if (!set_blocking(accepted_fd, false)) { log_perror(LIT("fcntl")); + close(accepted_fd); continue; } @@ -1294,6 +1363,7 @@ int main(int argc, char **argv) pollarray[free_index].revents = 0; byte_queue_init(&conns[free_index-1].input); byte_queue_init(&conns[free_index-1].output); + conns[free_index-1].ipaddr = (uint32_t) accepted_addr.sin_addr.s_addr; conns[free_index-1].closing = false; conns[free_index-1].served_count = 0; conns[free_index-1].creation_time = now; @@ -1338,7 +1408,7 @@ int main(int argc, char **argv) } } - } else if (!remove && (pollarray[i].revents & POLLIN)) { + } else if (!remove && (pollarray[i].revents & (POLLIN | POLLHUP | POLLERR))) { remove = read_from_socket(pollarray[i].fd, &conn->input); if (!remove) @@ -1365,6 +1435,8 @@ int main(int argc, char **argv) conn->start_time = -1; conn->closing = false; conn->creation_time = 0; + if ((pollarray[0].events & POLLIN) == 0) + pollarray[0].events |= POLLIN; num_conns--; } else { if (oldest == NULL || deadline_of(oldest) > deadline_of(conn)) oldest = conn; @@ -1379,7 +1451,9 @@ int main(int argc, char **argv) /* * Calculate the timeout for the next poll */ - if (log_empty()) { + if (pending_accept && num_conns < MAX_CONNECTIONS) + timeout = 0; + else if (log_empty()) { if (oldest == NULL) timeout = -1; else { @@ -1563,8 +1637,10 @@ int match_path_format(string path, char *fmt, ...) assert(p_stack[i].size > 0); if (f_stack[i].data[0] == ':') { - if (f_stack[i].size != 2) + if (f_stack[i].size != 2) { + va_end(args); return -1; // Invalid format + } switch (f_stack[i].data[1]) { case 'l': @@ -1580,26 +1656,35 @@ int match_path_format(string path, char *fmt, ...) size_t cur = 0; while (cur < p_stack[i].size && is_digit(p_stack[i].data[cur])) { int d = p_stack[i].data[cur] - '0'; - if (n > (UINT32_MAX - d) / 10) + if (n > (UINT32_MAX - d) / 10) { + va_end(args); return -1; // Overflow + } n = n * 10 + d; cur++; } - if (cur != p_stack[i].size) + if (cur != p_stack[i].size) { + va_end(args); return -1; // Component isn't a number + } uint32_t *p = va_arg(args, uint32_t*); *p = n; } break; default: + va_end(args); return -1; // Invalid formt } } else { - if (f_stack[i].size != p_stack[i].size) + if (f_stack[i].size != p_stack[i].size) { + va_end(args); return 1; // No match - if (memcmp(f_stack[i].data, p_stack[i].data, f_stack[i].size)) - return false; + } + if (memcmp(f_stack[i].data, p_stack[i].data, f_stack[i].size)) { + va_end(args); + return 1; + } } } @@ -1772,11 +1857,12 @@ size_t log_buffer_used = 0; bool log_failed = false; size_t log_total_size = 0; -void log_choose_file_name(char *dst, size_t max) +void log_choose_file_name(char *dst, size_t max, bool startup) { + size_t prev_size = -1; for (;;) { - int num = snprintf(dst, max, "logs/log_%d.txt", log_last_file_index); + int num = snprintf(dst, max, LOG_DIRECTORY "/log_%d.txt", log_last_file_index); if (num < 0 || (size_t) num >= max) { fprintf(stderr, "log_failed (%s:%d)\n", __FILE__, __LINE__); log_failed = true; @@ -1785,8 +1871,13 @@ void log_choose_file_name(char *dst, size_t max) dst[num] = '\0'; struct stat buf; - if (stat(dst, &buf) && errno == ENOENT) - break; + if (stat(dst, &buf)) { + if (errno == ENOENT) + break; + prev_size = -1; + } else { + prev_size = (size_t) buf.st_size; + } if (log_last_file_index == 100000000) { fprintf(stderr, "log_failed (%s:%d)\n", __FILE__, __LINE__); @@ -1795,6 +1886,20 @@ void log_choose_file_name(char *dst, size_t max) } log_last_file_index++; } + + // At startup don't create a new log file if the last one didn't reache its limit + if (startup && prev_size < LOG_FILE_LIMIT) { + + log_last_file_index--; + + int num = snprintf(dst, max, LOG_DIRECTORY "/log_%d.txt", log_last_file_index); + if (num < 0 || (size_t) num >= max) { + fprintf(stderr, "log_failed (%s:%d)\n", __FILE__, __LINE__); + log_failed = true; + return; + } + dst[num] = '\0'; + } } void log_init(void) @@ -1808,12 +1913,16 @@ void log_init(void) return; } - char name[1<<12]; - log_choose_file_name(name, sizeof(name)); - if (log_failed) { + if (mkdir(LOG_DIRECTORY, 0666) && errno != EEXIST) { + fprintf(stderr, "log_failed (%s:%d)\n", __FILE__, __LINE__); + log_failed = true; return; } + char name[1<<12]; + log_choose_file_name(name, sizeof(name), true); + if (log_failed) return; + log_fd = open(name, O_WRONLY | O_APPEND | O_CREAT, 0644); if (log_fd < 0) { fprintf(stderr, "log_failed (%s:%d)\n", __FILE__, __LINE__); @@ -1823,7 +1932,7 @@ void log_init(void) log_total_size = 0; - DIR *d = opendir("logs"); + DIR *d = opendir(LOG_DIRECTORY); if (d == NULL) { fprintf(stderr, "log_failed (%s:%d)\n", __FILE__, __LINE__); log_failed = true; @@ -1837,7 +1946,7 @@ void log_init(void) continue; char path[1<<12]; - int k = snprintf(path, SIZEOF(path), "logs/%s", dir->d_name); + int k = snprintf(path, SIZEOF(path), LOG_DIRECTORY "/%s", dir->d_name); if (k < 0 || k >= SIZEOF(path)) log_fatal(LIT("Bad format")); path[k] = '\0'; @@ -1851,7 +1960,7 @@ void log_init(void) } closedir(d); - static_assert(SIZEOF(size_t) > 4); + static_assert(SIZEOF(size_t) > 4, "It's assumed size_t can store a number of bytes in the order of 10gb"); if (log_total_size > (size_t) LOG_DIRECTORY_SIZE_LIMIT_MB * 1024 * 1024) { fprintf(stderr, "Log reached disk limit at startup\n"); log_failed = true; @@ -1880,12 +1989,8 @@ bool log_empty(void) void log_flush(void) { - DEBUG("flushing...\n"); - - if (log_failed || log_buffer_used == 0) { - DEBUG("flush ignored\n"); + if (log_failed || log_buffer_used == 0) return; - } /* * Rotate the file if the limit was reached @@ -1896,9 +2001,9 @@ void log_flush(void) log_failed = true; return; } - if (buf.st_size + log_buffer_used >= LOG_BUFFER_LIMIT) { + if (buf.st_size + log_buffer_used >= LOG_FILE_LIMIT) { char name[1<<12]; - log_choose_file_name(name, SIZEOF(name)); + log_choose_file_name(name, SIZEOF(name), false); if (log_failed) return; close(log_fd); @@ -1946,8 +2051,6 @@ void log_flush(void) } } - DEBUG("flushed %ld bytes\n", log_buffer_used); - assert(copied == log_buffer_used); log_buffer_used = 0; } @@ -2014,13 +2117,122 @@ void log_data(string str) } assert(str.size <= LOG_BUFFER_SIZE - log_buffer_used); + assert(log_buffer); memcpy(log_buffer + log_buffer_used, str.data, str.size); log_buffer_used += str.size; - - DEBUG("logged %ld bytes: [%.*s]\n", str.size, (int) str.size, str.data); } void log_perror(string str) { log_format("%.*s: %s\n", (int) str.size, str.data, strerror(errno)); } + +#if BLACKLIST + +uint32_t blocked_ips[BLACKLIST_LIMIT]; +int blocked_num = 0; + +bool ip_allowed(uint32_t ip) +{ + for (int i = 0; i < blocked_num; i++) + if (ip == blocked_ips[i]) + return false; + return true; +} +bool load_blacklist(void) +{ + int fd = open(BLACKLIST_FILE, O_RDONLY); + if (fd < 0) { + if (errno == ENOENT) + return true; + return false; + } + + struct stat buf; + if (fstat(fd, &buf) || !S_ISREG(buf.st_mode)) { + log_data(LIT("Couldn't stat file or it's not a regular file")); + close(fd); + return false; + } + size_t size = (size_t) buf.st_size; + + char *str = malloc(size); + if (str == NULL) { + log_data(LIT("out of memory")); + close(fd); + return false; + } + + size_t copied = 0; + while (copied < size) { + int n = read(fd, str + copied, size - copied); + if (n < 0) { + if (errno == EINTR) + continue; + log_perror(LIT("read")); + close(fd); + return false; + } + if (n == 0) + break; // EOF + copied += n; + } + + blocked_num = 0; + + // Parse the ip addresses + size_t cur = 0; + for (;;) { + // Get the start and end of the line + size_t start = cur; + while (cur < size && str[cur] != '\n' && str[cur] != '#') + cur++; + string line = { str + start, cur - start }; + line = trim(line); + + if (line.size > 0) { + + char temp[sizeof("xxx.xxx.xxx.xxx")]; + if (line.size >= sizeof(temp)) { + log_format("Invalid IP address \"%.*s\"\n", (int) line.size, line.data); + close(fd); + free(str); + return -1; + } + memcpy(temp, line.data, line.size); + temp[line.size] = '\0'; + + uint32_t ip; + if (inet_pton(AF_INET, temp, &ip) != 1) { + log_format("Invalid IP address \"%.*s\"\n", (int) line.size, line.data); + close(fd); + free(str); + return -1; + } + + if (blocked_num == BLACKLIST_LIMIT) { + log_format("IP buffer is too short\n"); + close(fd); + free(str); + return -1; + } + + blocked_ips[blocked_num++] = ip; + } + + if (cur < size && str[cur] == '#') + while (cur < size && str[cur] != '\n') + cur++; + + if (cur == size) + break; + assert(str[cur] == '\n'); + cur++; + } + + close(fd); + free(str); + return true; +} + +#endif /* BLACKLIST */ diff --git a/tests/stress.py b/tests/stress.py new file mode 100644 index 0000000..d2ca3a1 --- /dev/null +++ b/tests/stress.py @@ -0,0 +1,110 @@ +import time +import asyncio +import random +import inspect + +HOST = "127.0.0.1" +PORT = 8080 +NCLIENTS = 100 + +def print_bytes(prefix, data): + print(prefix, f"\\r\\n\n{prefix}".join(data.decode("utf-8").split("\r\n")), sep="") + +async def start_sending_request_then_close(client_id, reader, writer): + + print(client_id, inspect.currentframe().f_code.co_name) + + writer.write(b"GET /hello HT") + await writer.drain() + + writer.close() + + +async def close_while_waiting_response(client_id, reader, writer): + + print(client_id, inspect.currentframe().f_code.co_name) + + writer.write(b"GET /hello HTTP/1.1\r\nConnection: Keep-Alive\r\n\r\n") + await writer.drain() + + writer.close() + +async def send_simple_request(client_id, reader, writer): + + print(client_id, inspect.currentframe().f_code.co_name) + + writer.write(b"GET /hello HTTP/1.1\r\nConnection: Keep-Alive\r\n\r\n") + await writer.drain() + + expect = [ + b"HTTP/1.1 200 OK\r\nConnection: Close\r\nContent-Length: 13 \r\n\r\nHello, world!", + b"HTTP/1.1 200 OK\r\nConnection: Keep-Alive\r\nContent-Length: 13 \r\n\r\nHello, world!", + ] + expect.sort(key=lambda e: len(e)) + + unexpected = True + accum = b'' + for res in expect: + + if len(res) > len(accum): + data = await reader.read(len(res) - len(accum)) + if len(data) == 0: + return # We were disconnected + #print_bytes("> ", data) + accum = accum + data + + if accum == res: + unexpected = False + break + + if unexpected: + raise RuntimeError("Unexpected response") + + +async def send_request_pipeline(client_id, reader, writer): + + print(client_id, inspect.currentframe().f_code.co_name) + + base = b"GET /hello HTTP/1.1\r\nConnection: Keep-Alive\r\n\r\n" + pipeline = base * 300 + print(len(pipeline)) + + writer.write(pipeline) + await writer.drain() + + asyncio.sleep(1) + + writer.close() + await writer.wait_closed() + +actions = [ + send_request_pipeline, + start_sending_request_then_close, + close_while_waiting_response, +# send_simple_request, +] + +async def client(client_id): + + reader = None + + while True: + if reader is None or reader.at_eof(): + #print('Connecting') + reader, writer = await asyncio.open_connection(HOST, PORT) + try: + await asyncio.sleep(0.1) + await actions[random.randint(0, len(actions)-1)](client_id, reader, writer) + except ConnectionResetError as e: + print(e) + print('Close the connection') + writer.close() + await writer.wait_closed() + +async def main(): + tasks = [] + for i in range(NCLIENTS): + tasks.append(asyncio.create_task(client(i))) + await asyncio.gather(*tasks) + +asyncio.run(main()) \ No newline at end of file diff --git a/tests/test.py b/tests/test.py index f28ecbe..c05cb0b 100644 --- a/tests/test.py +++ b/tests/test.py @@ -6,7 +6,7 @@ @dataclass class Delay: - ms: float + sec: float @dataclass class Send: @@ -112,6 +112,26 @@ def run_test(test, addr, port): Delay(6), Close() ], + [ + # Test invalid protocol version + Send(b"GET /hello HTTP/2\r\nConnection: Keep-Alive\r\n\r\n"), + Recv(b"HTTP/1.1 505 HTTP Version Not Supported\r\nConnection: Keep-Alive\r\nContent-Length: 0 \r\n\r\n"), + ], + [ + # Send request in pieces + Send(b"GET /hello HT"), + Delay(1), + Send(b"TP/1.1\r\nConnection: Ke"), + Delay(1), + Send(b"ep-Alive\r\n\r\n"), + Recv(b"HTTP/1.1 200 OK\r\nConnection: Keep-Alive\r\nContent-Length: 13 \r\n\r\nHello, world!"), + ], + [ + # Test pipelining + Send(b"GET /hello HTTP/1.1\r\nConnection: Keep-Alive\r\n\r\nGET /hello HTTP/1.1\r\nConnection: Keep-Alive\r\n\r\n"), + Recv(b"HTTP/1.1 200 OK\r\nConnection: Keep-Alive\r\nContent-Length: 13 \r\n\r\nHello, world!"), + Recv(b"HTTP/1.1 200 OK\r\nConnection: Keep-Alive\r\nContent-Length: 13 \r\n\r\nHello, world!"), + ], ] p = subprocess.Popen(['../serve_cov'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)