diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index a5c9a1a..43a2d91 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,4 +1,5 @@ cmake_minimum_required(VERSION 3.17) +add_subdirectory(hw_3_tcp) add_subdirectory(hw_1_process) add_subdirectory(hw_2_logger) diff --git a/lib/hw_3_tcp/CMakeLists.txt b/lib/hw_3_tcp/CMakeLists.txt new file mode 100644 index 0000000..9e8e0b7 --- /dev/null +++ b/lib/hw_3_tcp/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.17) +project(hw_3_tcp) + +set(CMAKE_CXX_STANDARD 17) + +add_library(hw_3_tcp STATIC src/server.cpp src/connection.cpp) +target_include_directories(hw_3_tcp PUBLIC include) + +target_link_libraries(hw_3_tcp hw_1_descriptor) +target_link_libraries(hw_3_tcp hw_2_logger) + +add_executable(test_client_hw_3 tests/test_client.cpp) +target_link_libraries(test_client_hw_3 hw_3_tcp) + +add_executable(test_server_hw_3 tests/test_server.cpp) +target_link_libraries(test_server_hw_3 hw_3_tcp) diff --git a/lib/hw_3_tcp/include/tcp.hpp b/lib/hw_3_tcp/include/tcp.hpp new file mode 100644 index 0000000..93df7a0 --- /dev/null +++ b/lib/hw_3_tcp/include/tcp.hpp @@ -0,0 +1,62 @@ +#pragma once + +#include "descriptor.hpp" +#include +#include + +namespace tcp { + +class TcpError : public std::runtime_error { + using std::runtime_error::runtime_error; +}; + +class ConnectionError : public TcpError { + using TcpError::TcpError; +}; + +class Connection { +public: + Connection() = default; + Connection(desc::Descriptor &&client_socket, std::string addr, uint16_t port); + + Connection(Connection &other) = delete; + Connection(Connection &&other) noexcept; + + Connection &operator=(Connection &other) = delete; + Connection &operator=(Connection &&other) noexcept; + + virtual ~Connection() = default; + + void Close(); + + size_t Read(char *, size_t); + void ReadExact(char *, size_t); + + size_t Write(const char *, size_t); + void WriteExact(const char *, size_t); + + void SetTimeout(size_t sec, size_t ms); + void SetTimeout(const timeval &timeout); + + [[nodiscard]] bool IsOpen() const; + [[nodiscard]] int GetSocket() const; + [[nodiscard]] std::string GetAddr() const; + [[nodiscard]] std::string GetPort() const; + +protected: + desc::Descriptor socket_; + std::optional timeout_; + std::string addr_; + uint16_t port_; + + void SetTimeoutInternal(const timeval &timeout); + void LogErrorAndThrow(); +}; + +class ClientConnection : public Connection { +public: + ClientConnection(const std::string &addr, int port); + void Connect(const std::string &addr, int port); +}; + +} // namespace tcp diff --git a/lib/hw_3_tcp/include/tcp_server.hpp b/lib/hw_3_tcp/include/tcp_server.hpp new file mode 100644 index 0000000..075634b --- /dev/null +++ b/lib/hw_3_tcp/include/tcp_server.hpp @@ -0,0 +1,43 @@ +#pragma once + +#include "descriptor.hpp" +#include "tcp.hpp" +#include +#include + +namespace tcp { + +class ServerError : public TcpError { + using TcpError::TcpError; +}; + +class ServerAcceptError : public ServerError { + using ServerError::ServerError; +}; + +class Server { +public: + Server() = default; + Server(const std::string &addr, int port, size_t max_connections); + + Server(Server &other) = delete; + Server(Server &&other) noexcept; + + Server &operator=(Server &other) = delete; + Server &operator=(Server &&other) noexcept; + + void Open(const std::string &addr, int port, size_t max_connections); + void Close(); + Connection Accept(); + + void SetTimeout(size_t sec, size_t ms); + +private: + desc::Descriptor server_socket_; + std::optional timeout_; + + void LogErrorAndThrow(); + void SetTimeoutInternal(const timeval &timeout); +}; + +} // namespace tcp diff --git a/lib/hw_3_tcp/src/connection.cpp b/lib/hw_3_tcp/src/connection.cpp new file mode 100644 index 0000000..c637365 --- /dev/null +++ b/lib/hw_3_tcp/src/connection.cpp @@ -0,0 +1,209 @@ +#include "tcp.hpp" + +#include "logger.hpp" +#include +#include +#include +#include +#include +#include +#include + +namespace tcp { + +Connection::Connection(Connection &&other) noexcept + : socket_(std::move(other.socket_)) { + + std::swap(addr_, other.addr_); + std::swap(port_, other.port_); + std::swap(timeout_, other.timeout_); +} + +Connection &Connection::operator=(Connection &&other) noexcept { + if (this != &other) { + socket_ = std::move(other.socket_); + + std::swap(addr_, other.addr_); + std::swap(port_, other.port_); + std::swap(timeout_, other.timeout_); + } + return *this; +} + +Connection::Connection(desc::Descriptor &&client_socket, std::string addr, + uint16_t port) + : socket_(std::move(client_socket)), addr_(std::move(addr)), port_(port) {} + +void Connection::Close() { + log::INFO("close connection on fd = " + std::to_string(*socket_)); + socket_.close(); + timeout_ = std::nullopt; +} + +size_t Connection::Read(char *data, size_t len) { + if (!socket_.isValid()) { + log::ERROR("try to read from invalid descriptor"); + throw ConnectionError("can't read: invalid descriptor"); + } + ssize_t bytes_read = ::read(*socket_, data, len); + + if (bytes_read < 0) { + LogErrorAndThrow(); + } + + if (bytes_read == 0) { + log::INFO("other side close connection on fd= " + std::to_string(*socket_)); + socket_.close(); + } + + log::DEBUG("from socket " + std::to_string(*socket_) + + ", read bytes = " + std::to_string(bytes_read)); + return static_cast(bytes_read); +} + +void Connection::ReadExact(char *data, size_t len) { + size_t rest = len; + size_t position = 0; + size_t bytes_read = 0; + while ((bytes_read = Read(static_cast(data) + position, rest)) > 0) { + if (bytes_read >= rest) { + return; + } + rest -= bytes_read; + position += bytes_read; + } + if (bytes_read == 0) { + log::WARN("channel was closed while reading, rest_bytes= " + + std::to_string(rest)); + throw ConnectionError("channel was closed while reading"); + } +} + +size_t Connection::Write(const char *data, size_t len) { + if (!socket_.isValid()) { + log::ERROR("can't write message: invalid descriptor"); + throw ConnectionError("can't write message: invalid descriptor"); + } + ssize_t bytes_wrote = ::write(*socket_, data, len); + + if (bytes_wrote < 0) { + LogErrorAndThrow(); + } + if (bytes_wrote == 0) { + log::INFO("other side close connection on fd= " + std::to_string(*socket_)); + socket_.close(); + } + log::DEBUG("to socket " + std::to_string(*socket_) + + ", wrote bytes = " + std::to_string(bytes_wrote)); + return static_cast(bytes_wrote); +} + +void Connection::WriteExact(const char *data, size_t rest) { + size_t position = 0; + size_t bytes_write = 0; + while ((bytes_write = + Write(static_cast(data) + position, rest)) > 0) { + if (bytes_write >= rest) { + return; + } + rest -= bytes_write; + position += bytes_write; + } + if (bytes_write == 0) { + log::WARN("channel was closed while writing, rest_bytes= " + + std::to_string(rest)); + throw ConnectionError("channel was closed while writing"); + } +} + +void Connection::SetTimeout(size_t sec, size_t ms) { + timeval timeout{}; + timeout.tv_sec = sec; + timeout.tv_usec = ms; + + SetTimeout(timeout); +} + +void Connection::SetTimeout(const timeval &timeout) { + if (!socket_.isValid()) { + timeout_ = timeout; + return; + } + SetTimeoutInternal(timeout); +} + +void Connection::SetTimeoutInternal(const timeval &timeout) { + + assert(socket_.isValid()); + + log::DEBUG("try to set timeout s= " + std::to_string(timeout.tv_sec) + + ", ms= " + std::to_string(timeout.tv_usec)); + + if (::setsockopt(*socket_, SOL_SOCKET, SO_SNDTIMEO, &timeout, + sizeof(timeout)) < 0) { + log::WARN("can't set out timeout to socket: " + std::to_string(*socket_)); + } + if (::setsockopt(*socket_, SOL_SOCKET, SO_RCVTIMEO, &timeout, + sizeof(timeout)) < 0) { + log::WARN("can't set in timeout to socket: " + std::to_string(*socket_)); + } +} + +bool Connection::IsOpen() const { return socket_.isValid(); } + +void Connection::LogErrorAndThrow() { + + socket_.close(); + log::ERROR(std::strerror(errno)); + throw ConnectionError(std::strerror(errno)); +} + +int Connection::GetSocket() const { return *socket_; } + +std::string Connection::GetAddr() const { return addr_; } + +std::string Connection::GetPort() const { return std::to_string(port_); } + +ClientConnection::ClientConnection(const std::string &addr, int port) { + addr_ = addr; + port_ = port; + Connect(addr, port); +} + +void ClientConnection::Connect(const std::string &addr, int port) { + + if (socket_.isValid()) { + log::DEBUG("try to create new connection on p=" + std::to_string(port) + + ", addr= " + addr + "; close old one on "); + Close(); + } else { + log::DEBUG("try to create new connection on p=" + std::to_string(port) + + ", addr= " + addr); + } + + try { + socket_ = desc::Descriptor(::socket(AF_INET, SOCK_STREAM, 0)); + } catch (const desc::DescriptorError &err) { + LogErrorAndThrow(); + } + + if (timeout_) { + SetTimeoutInternal(*timeout_); + } + + sockaddr_in sock_addr{}; + sock_addr.sin_family = AF_INET; + sock_addr.sin_port = ::htons(port); + int result = ::inet_aton(addr.c_str(), &sock_addr.sin_addr); + if (result <= 0) { + LogErrorAndThrow(); + } + + result = ::connect(*socket_, reinterpret_cast(&sock_addr), + sizeof(sock_addr)); + if (result != 0) { + LogErrorAndThrow(); + } +} + +} // namespace tcp diff --git a/lib/hw_3_tcp/src/server.cpp b/lib/hw_3_tcp/src/server.cpp new file mode 100644 index 0000000..a91bf99 --- /dev/null +++ b/lib/hw_3_tcp/src/server.cpp @@ -0,0 +1,139 @@ +#include "tcp_server.hpp" +#include +#include +#include +#include +#include + +namespace tcp { + +Server::Server(const std::string &addr, int port, size_t max_connections) { + Open(addr, port, max_connections); +} +Server::Server(Server &&other) noexcept + : server_socket_(std::move(other.server_socket_)), + timeout_(other.timeout_) { + + other.timeout_.reset(); +} +Server &Server::operator=(Server &&other) noexcept { + if (this != &other) { + server_socket_ = std::move(other.server_socket_); + timeout_ = other.timeout_; + other.timeout_.reset(); + } + return *this; +} + +void Server::Open(const std::string &addr, int port, size_t max_connections) { + + if (server_socket_.isValid()) { + log::DEBUG("try to create new server, listen on p=" + std::to_string(port) + + ", addr= " + addr + "; close old one on "); + Close(); + } else { + log::DEBUG("try to create new server, listen on p=" + std::to_string(port) + + ", addr= " + addr); + } + + try { + server_socket_ = desc::Descriptor{::socket(AF_INET, SOCK_STREAM, 0)}; + } catch (const desc::DescriptorError &err) { + LogErrorAndThrow(); + } + + int opt = 1; + int result; + result = ::setsockopt(*server_socket_, SOL_SOCKET, SO_REUSEADDR, &opt, + sizeof(int)); + if (result < 0) { + LogErrorAndThrow(); + } + + if (timeout_) { + SetTimeoutInternal(*timeout_); + } + + sockaddr_in sock_addr{}; + sock_addr.sin_family = AF_INET; + sock_addr.sin_port = ::htons(8080); + sock_addr.sin_addr = {::htonl(INADDR_ANY)}; + + result = ::bind(*server_socket_, reinterpret_cast(&sock_addr), + sizeof(sock_addr)); + if (result != 0) { + LogErrorAndThrow(); + } + + result = ::listen(*server_socket_, max_connections); + if (result != 0) { + LogErrorAndThrow(); + } +} +void Server::Close() { + + log::INFO("close server on fd = " + std::to_string(*server_socket_)); + server_socket_.close(); + timeout_ = std::nullopt; +} + +void Server::LogErrorAndThrow() { + server_socket_.close(); + log::ERROR(std::strerror(errno)); + throw ConnectionError(std::strerror(errno)); +} + +Connection Server::Accept() { + + sockaddr_in client_sock_addr{}; + socklen_t s = sizeof(sockaddr_in); + + desc::Descriptor client_fd; + try { + client_fd = desc::Descriptor{::accept( + *server_socket_, reinterpret_cast(&client_sock_addr), &s)}; + } catch (const desc::DescriptorError &err) { + log::WARN("server_accept error: " + std::string(std::strerror(errno))); + throw ServerAcceptError("server_accept error: " + + std::string(std::strerror(errno))); + } + + std::string client_addr{::inet_ntoa(client_sock_addr.sin_addr)}; + log::INFO("get client from: addr= " + client_addr + + ", p= " + std::to_string(client_sock_addr.sin_port)); + + return {std::move(client_fd), client_addr, client_sock_addr.sin_port}; +} + +void Server::SetTimeout(size_t sec, size_t ms) { + timeval timeout{}; + timeout.tv_sec = sec; + timeout.tv_usec = ms; + + if (!server_socket_.isValid()) { + timeout_ = timeout; + return; + } + SetTimeoutInternal(timeout); +} + +void Server::SetTimeoutInternal(const timeval &timeout) { + + assert(server_socket_.isValid()); + + log::DEBUG("try to set timeout s= " + std::to_string(timeout.tv_sec) + + ", ms= " + std::to_string(timeout.tv_usec)); + + if (::setsockopt(*server_socket_, SOL_SOCKET, SO_SNDTIMEO, &timeout, + sizeof(timeout)) < 0) { + log::WARN("can't set out timeout to socket: " + + std::to_string(*server_socket_)); + } + if (::setsockopt(*server_socket_, SOL_SOCKET, SO_RCVTIMEO, &timeout, + sizeof(timeout)) < 0) { + log::WARN("can't set in timeout to socket: " + + std::to_string(*server_socket_)); + } +} + +} // namespace tcp diff --git a/lib/hw_3_tcp/tests/test_client.cpp b/lib/hw_3_tcp/tests/test_client.cpp new file mode 100644 index 0000000..8d42d66 --- /dev/null +++ b/lib/hw_3_tcp/tests/test_client.cpp @@ -0,0 +1,15 @@ +#include "logger.hpp" +#include "tcp.hpp" + +int main() { + log::init_with_stderr_logger(log::Level::DEBUG); + + tcp::ClientConnection con{"127.0.0.1", 8080}; + + while (con.IsOpen()) { + std::string buf(1111, '\0'); + size_t read = 0; + read = con.Read(buf.data(), buf.size()); + std::cout << buf.substr(0, read) << std::flush; + } +} diff --git a/lib/hw_3_tcp/tests/test_server.cpp b/lib/hw_3_tcp/tests/test_server.cpp new file mode 100644 index 0000000..835889a --- /dev/null +++ b/lib/hw_3_tcp/tests/test_server.cpp @@ -0,0 +1,34 @@ +#include "logger.hpp" +#include "tcp_server.hpp" + +int main() { + log::init_with_stderr_logger(log::Level::DEBUG); + + tcp::Server server{"127.0.0.1", 8080, 100}; + server.SetTimeout(10, 0); + tcp::Connection con; + try { + con = server.Accept(); + } catch (const tcp::ServerAcceptError &e) { + std::cout << e.what() << std::endl; + } + con.SetTimeout(5, 0); + + try { + while (con.IsOpen()) { + std::string buf(1111, '\0'); + size_t read = 0; + read = con.Read(buf.data(), buf.size()); + std::cout << buf.substr(0, read) << std::flush; + } + } catch (const tcp::ConnectionError &) { + } + con = server.Accept(); + + while (con.IsOpen()) { + std::string buf(1111, '\0'); + size_t read = 0; + read = con.Read(buf.data(), buf.size()); + std::cout << "con2 " << buf.substr(0, read) << std::flush; + } +}