Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions include/libserial/serial.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,26 @@ void setMinNumberCharRead(uint16_t);
*/
void setBaudRate(BaudRate baud_rate);

/**
* @brief Sets the maximum safe read size
*
* Configures the maximum number of bytes that can be read
* in a single read operation to prevent excessive memory usage.
*
* @param size The desired maximum safe read size in bytes
*/
void setMaxSafeReadSize(size_t size);

/**
* @brief Gets the maximum safe read size
*
* Retrieves the maximum number of bytes that can be read
* in a single read operation to prevent excessive memory usage.
*
* @return The maximum safe read size in bytes
*/
size_t getMaxSafeReadSize() const;

/**
* @brief Gets the current baud rate
*
Expand All @@ -318,8 +338,9 @@ int getBaudRate() const;
void setFdForTest(int fd) {
fd_serial_port_ = fd;
}

// For testing - allow injection of mock functions
// WARNING: Test helper only! This function allows injection of custom
// system call functions for testing error handling. It should NEVER be
// used in production code.
void setSystemCallFunctions(
std::function<int(struct pollfd*, nfds_t, int)> poll_func,
std::function<ssize_t(int, void*, size_t)> read_func) {
Expand Down Expand Up @@ -410,8 +431,9 @@ std::chrono::milliseconds write_timeout_ms_{1000}; ///< Write timeout in mill
*
* Defines the maximum number of bytes that can be read
* in a single read operation to prevent excessive memory usage.
* Default is 2048 bytes (2KB).
*/
static constexpr size_t kMaxSafeReadSize = 2048; // 2KB limit
size_t max_safe_read_size_{2048}; // 2KB limit

/**
* @brief Timeout value in milliseconds
Expand Down
19 changes: 14 additions & 5 deletions src/serial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ size_t Serial::read(std::shared_ptr<std::string> buffer) {
}

buffer->clear();
buffer->resize(kMaxSafeReadSize);
buffer->resize(max_safe_read_size_);

struct pollfd fd_poll;
fd_poll.fd = fd_serial_port_;
Expand All @@ -83,7 +83,8 @@ size_t Serial::read(std::shared_ptr<std::string> buffer) {
}

// Data available: do the read
ssize_t bytes_read = read_(fd_serial_port_, const_cast<char*>(buffer->data()), kMaxSafeReadSize);
ssize_t bytes_read = read_(fd_serial_port_, const_cast<char*>(buffer->data()),
max_safe_read_size_);
if (bytes_read < 0) {
throw IOException(std::string("Error reading from serial port: ") + strerror(errno));
}
Expand All @@ -108,7 +109,7 @@ size_t Serial::readBytes(std::shared_ptr<std::string> buffer, size_t num_bytes)
buffer->clear();
buffer->resize(num_bytes);

ssize_t bytes_read = ::read(fd_serial_port_, buffer->data(), num_bytes); // codacy-ignore[buffer-boundary]
ssize_t bytes_read = read_(fd_serial_port_, buffer->data(), num_bytes); // codacy-ignore[buffer-boundary]

if (bytes_read < 0) {
throw IOException("Error reading from serial port: " + std::string(strerror(errno)));
Expand All @@ -130,9 +131,9 @@ size_t Serial::readUntil(std::shared_ptr<std::string> buffer, char terminator) {

while (temp_char != terminator) {
// Check buffer size limit to prevent excessive memory usage
if (buffer->size() >= kMaxSafeReadSize) {
if (buffer->size() >= max_safe_read_size_) {
throw IOException("Read buffer exceeded maximum size limit of " +
std::to_string(kMaxSafeReadSize) +
std::to_string(max_safe_read_size_) +
" bytes without finding terminator");
}
// Check timeout if enabled (0 means no timeout)
Expand Down Expand Up @@ -344,6 +345,14 @@ void Serial::setMinNumberCharRead(uint16_t num) {
this->setTermios2();
}

void Serial::setMaxSafeReadSize(size_t size) {
max_safe_read_size_ = size;
}

size_t Serial::getMaxSafeReadSize() const {
return max_safe_read_size_;
}

int Serial::getAvailableData() const {
int bytes_available;
if (ioctl(fd_serial_port_, FIONREAD, &bytes_available) < 0) {
Expand Down
87 changes: 87 additions & 0 deletions test/test_serial_pty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,39 @@ TEST_F(PseudoTerminalTest, ReadBytesWithInvalidNumBytes) {
}, libserial::IOException);
}

TEST_F(PseudoTerminalTest, ReadBytesWithReadFail) {
libserial::Serial serial_port;

serial_port.open(slave_port_);
serial_port.setBaudRate(9600);
serial_port.setCanonicalMode(libserial::CanonicalMode::DISABLE);

auto read_buffer = std::make_shared<std::string>();

for (const auto& [error_num, error_msg] : errors_read_) {
serial_port.setSystemCallFunctions(
[](struct pollfd*, nfds_t, int) -> int {
return 1;
},
[error_num](int, void*, size_t) -> ssize_t {
errno = error_num;
return -1;
});

auto expected_what = "Error reading from serial port: " + error_msg;

EXPECT_THROW({
try {
serial_port.readBytes(read_buffer, 10);
}
catch (const libserial::IOException& e) {
EXPECT_STREQ(expected_what.c_str(), e.what());
throw;
}
}, libserial::IOException);
}
}

TEST_F(PseudoTerminalTest, ReadBytesCanonicalMode) {
libserial::Serial serial_port;

Expand Down Expand Up @@ -472,6 +505,25 @@ TEST_F(PseudoTerminalTest, ReadUntil) {
EXPECT_EQ(*read_buffer, "Read Until!");
}

TEST_F(PseudoTerminalTest, ReadUntilWithNullBuffer) {
libserial::Serial serial_port;

serial_port.open(slave_port_);
serial_port.setBaudRate(9600);

std::shared_ptr<std::string> null_buffer;

EXPECT_THROW({
try {
serial_port.readUntil(null_buffer, '!');
}
catch (const libserial::IOException& e) {
EXPECT_STREQ("Null pointer passed to readUntil function", e.what());
throw;
}
}, libserial::IOException);
}

TEST_F(PseudoTerminalTest, ReadUntilTimeout) {
libserial::Serial serial_port;

Expand Down Expand Up @@ -552,3 +604,38 @@ TEST_F(PseudoTerminalTest, ReadUntilWithPollFail) {
}, libserial::IOException);
}
}

TEST_F(PseudoTerminalTest, ReadUntilWithOverflowBuffer) {
libserial::Serial serial_port;

serial_port.open(slave_port_);
serial_port.setBaudRate(9600);
EXPECT_NO_THROW(serial_port.setMaxSafeReadSize(10)); // Set max safe read size to 10 bytes

std::string test_message(15, 'a');
test_message.push_back('\n');

ssize_t bytes_written = write(master_fd_, test_message.c_str(), test_message.length());
ASSERT_GT(bytes_written, 0) << "Failed to write to master end";

// Give time for data to propagate
fsync(master_fd_);
std::this_thread::sleep_for(std::chrono::milliseconds(100));

// Test reading with shared pointer - only read what's available
auto read_buffer = std::make_shared<std::string>();

auto expected_what = "Read buffer exceeded maximum size limit of " +
std::to_string(serial_port.getMaxSafeReadSize()) +
" bytes without finding terminator";

EXPECT_THROW({
try {
serial_port.readUntil(read_buffer, '!');
}
catch (const libserial::IOException& e) {
EXPECT_STREQ(expected_what.c_str(), e.what());
throw;
}
}, libserial::IOException);
}