diff --git a/Cargo.lock b/Cargo.lock index ad655f1a..f0e3402b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -361,6 +361,7 @@ dependencies = [ "lru 0.16.2", "mockall", "mockito", + "nix 0.29.0", "once_cell", "owo-colors", "ratatui", diff --git a/Cargo.toml b/Cargo.toml index c8135235..dcab70ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ regex = "1.12.2" lazy_static = "1.5" ctrlc = "3.5.1" signal-hook = "0.3.18" +nix = { version = "0.29", features = ["poll"] } atty = "0.2.14" arrayvec = "0.7.6" smallvec = "1.15.1" diff --git a/README.md b/README.md index 79003a87..a755046f 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,14 @@ bssh -o StrictHostKeyChecking=no user@host bssh -Q cipher ``` +**PTY Session Escape Sequences:** + +Like OpenSSH, bssh supports escape sequences in PTY sessions. These must be typed at the beginning of a line (after pressing Enter): + +| Escape | Description | +|--------|-------------| +| `~.` | Disconnect from the remote host | + ### Port Forwarding ```bash # Local port forwarding (-L) diff --git a/src/pty/session/constants.rs b/src/pty/session/constants.rs index e1a7e6d1..0f6c82c5 100644 --- a/src/pty/session/constants.rs +++ b/src/pty/session/constants.rs @@ -13,6 +13,13 @@ // limitations under the License. //! Terminal constants and key sequence definitions +//! +//! NOTE: Many key sequence constants are currently unused since we switched to +//! raw byte passthrough (see issue #87), but are kept for reference and potential +//! future debugging use. + +// Allow dead code for unused key sequence constants +#![allow(dead_code)] // Buffer size constants for allocation optimization // These values are chosen based on empirical testing and SSH protocol characteristics @@ -20,21 +27,18 @@ /// Maximum size for terminal key sequences (ANSI escape sequences are typically 3-7 bytes) /// Value: 8 bytes - Accommodates the longest standard ANSI sequences (F-keys: ESC[2x~) /// Rationale: Most key sequences are 1-5 bytes, 8 provides safe headroom without waste -#[allow(dead_code)] pub const MAX_KEY_SEQUENCE_SIZE: usize = 8; /// Buffer size for SSH I/O operations (4KB aligns with typical SSH packet sizes) /// Value: 4096 bytes - Matches common SSH packet fragmentation boundaries /// Rationale: SSH protocol commonly uses 4KB packets; larger buffers reduce syscalls /// but increase memory usage. 4KB provides optimal balance for interactive sessions. -#[allow(dead_code)] pub const SSH_IO_BUFFER_SIZE: usize = 4096; /// Maximum size for terminal output chunks processed at once /// Value: 1024 bytes - Balance between responsiveness and efficiency /// Rationale: Smaller chunks improve perceived responsiveness for interactive use, /// while still being large enough to batch terminal escape sequences efficiently. -#[allow(dead_code)] pub const TERMINAL_OUTPUT_CHUNK_SIZE: usize = 1024; /// PTY message channel sizing: @@ -57,7 +61,7 @@ pub const INPUT_POLL_TIMEOUT_MS: u64 = 500; /// - Tasks should check cancellation signal frequently (10-50ms intervals) pub const TASK_CLEANUP_TIMEOUT_MS: u64 = 100; -// Const arrays for frequently used key sequences to avoid repeated allocations +// Const arrays for frequently used key sequences to avoid repeated allocations. /// Control key sequences - frequently used in terminal input pub const CTRL_C_SEQUENCE: &[u8] = &[0x03]; // Ctrl+C (SIGINT) pub const CTRL_D_SEQUENCE: &[u8] = &[0x04]; // Ctrl+D (EOF) diff --git a/src/pty/session/escape_filter.rs b/src/pty/session/escape_filter.rs index 428ac590..d6fdfe28 100644 --- a/src/pty/session/escape_filter.rs +++ b/src/pty/session/escape_filter.rs @@ -855,4 +855,326 @@ mod tests { filter.pending_buffer = b"\x1b]99999999999;test\x07".to_vec(); assert_eq!(filter.parse_osc_param(), None); } + + // ======================================== + // Additional edge case tests + // ======================================== + + #[test] + fn test_escape_at_buffer_boundary() { + let mut filter = EscapeSequenceFilter::new(); + + // ESC at the very end of buffer + let output1 = filter.filter(b"Hello\x1b"); + assert_eq!(output1, b"Hello"); + + // Continue with CSI sequence + let output2 = filter.filter(b"[31mRed"); + assert_eq!(output2, b"\x1b[31mRed"); + } + + #[test] + fn test_consecutive_escape_sequences() { + let mut filter = EscapeSequenceFilter::new(); + + // Multiple consecutive color codes + let input = b"\x1b[31m\x1b[1m\x1b[4mBold Red Underline\x1b[0m"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_interleaved_text_and_sequences() { + let mut filter = EscapeSequenceFilter::new(); + + // Text, escape, text, escape pattern + let input = b"A\x1b[1mB\x1b[0mC"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_empty_input() { + let mut filter = EscapeSequenceFilter::new(); + let output = filter.filter(b""); + assert!(output.is_empty()); + } + + #[test] + fn test_single_escape_byte() { + let mut filter = EscapeSequenceFilter::new(); + let output = filter.filter(b"\x1b"); + assert!(output.is_empty(), "Single ESC should be buffered"); + } + + #[test] + fn test_incomplete_csi_then_text() { + let mut filter = EscapeSequenceFilter::new(); + + // Incomplete CSI + let output1 = filter.filter(b"\x1b["); + assert!(output1.is_empty()); + + // Non-sequence character should flush buffer + // Actually, the filter waits for terminator, so let's test complete sequence + let output2 = filter.filter(b"m"); + assert_eq!(output2, b"\x1b[m"); + } + + #[test] + fn test_osc_with_st_terminator() { + let mut filter = EscapeSequenceFilter::new(); + + // OSC with ST terminator instead of BEL + let input = b"\x1b]0;My Title\x1b\\"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_osc_clipboard_response_filtered() { + let mut filter = EscapeSequenceFilter::new(); + + // OSC 52 clipboard response (base64 encoded) + let input = b"\x1b]52;c;SGVsbG8gV29ybGQ=\x07"; + let output = filter.filter(input); + assert!(output.is_empty(), "OSC 52 response should be filtered"); + } + + #[test] + fn test_csi_with_intermediate_bytes() { + let mut filter = EscapeSequenceFilter::new(); + + // CSI with space as intermediate byte (ESC [ 0 SP q) + let input = b"\x1b[0 q"; // Cursor style + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_csi_with_multiple_params() { + let mut filter = EscapeSequenceFilter::new(); + + // SGR with many parameters + let input = b"\x1b[38;2;255;128;64;48;2;0;0;0m"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_binary_data_passthrough() { + let mut filter = EscapeSequenceFilter::new(); + + // Raw binary data (not escape sequences) + let input = [0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD]; + let output = filter.filter(&input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_unicode_passthrough() { + let mut filter = EscapeSequenceFilter::new(); + + // UTF-8 encoded text + let input = "Hello δΈ–η•Œ 🌍".as_bytes(); + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_filter_state_normal_initial() { + let filter = EscapeSequenceFilter::new(); + assert_eq!(filter.state, FilterState::Normal); + } + + #[test] + fn test_default_trait() { + let _filter = EscapeSequenceFilter::default(); + } + + #[test] + fn test_reset_clears_state() { + let mut filter = EscapeSequenceFilter::new(); + + // Put filter in non-normal state + let _ = filter.filter(b"\x1b[?"); + + filter.reset(); + + assert_eq!(filter.state, FilterState::Normal); + assert!(filter.pending_buffer.is_empty()); + assert!(filter.sequence_start.is_none()); + } + + #[test] + fn test_osc_4_color_palette_filtered() { + let mut filter = EscapeSequenceFilter::new(); + + // OSC 4 color palette response + let input = b"\x1b]4;0;rgb:0000/0000/0000\x07"; + let output = filter.filter(input); + assert!(output.is_empty(), "OSC 4 response should be filtered"); + } + + #[test] + fn test_osc_11_background_color_filtered() { + let mut filter = EscapeSequenceFilter::new(); + + // OSC 11 background color response + let input = b"\x1b]11;rgb:ffff/ffff/ffff\x07"; + let output = filter.filter(input); + assert!(output.is_empty(), "OSC 11 response should be filtered"); + } + + #[test] + fn test_dcs_tmux_passthrough() { + let mut filter = EscapeSequenceFilter::new(); + + // tmux passthrough DCS (not a response) + let input = b"\x1bPtmux;\x1b\x1b[31mred\x1b\x1b[0m\x1b\\"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec(), "tmux DCS should pass through"); + } + + #[test] + fn test_multiple_filtered_sequences() { + let mut filter = EscapeSequenceFilter::new(); + + // Multiple XTGETTCAP responses in sequence + let input = b"\x1bP+r736574726762\x1b\\\x1bP+r636c656172\x1b\\"; + let output = filter.filter(input); + assert!( + output.is_empty(), + "Multiple XTGETTCAP responses should all be filtered" + ); + } + + #[test] + fn test_filtered_then_passthrough() { + let mut filter = EscapeSequenceFilter::new(); + + // XTGETTCAP (filtered) followed by normal text + let input = b"\x1bP+r736574726762\x1b\\Hello"; + let output = filter.filter(input); + assert_eq!(output, b"Hello", "Text after filtered sequence should pass"); + } + + #[test] + fn test_passthrough_then_filtered() { + let mut filter = EscapeSequenceFilter::new(); + + // Normal text followed by filtered response + let input = b"Hello\x1bP+r736574726762\x1b\\"; + let output = filter.filter(input); + assert_eq!( + output, b"Hello", + "Text before filtered sequence should pass" + ); + } + + #[test] + fn test_csi_erase_display() { + let mut filter = EscapeSequenceFilter::new(); + + // CSI 2 J - Erase Display + let input = b"\x1b[2J"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_csi_erase_line() { + let mut filter = EscapeSequenceFilter::new(); + + // CSI K - Erase Line + let input = b"\x1b[K"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_csi_scroll_region() { + let mut filter = EscapeSequenceFilter::new(); + + // CSI r - Set Scroll Region + let input = b"\x1b[1;24r"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_csi_save_restore_cursor() { + let mut filter = EscapeSequenceFilter::new(); + + // CSI s / CSI u - Save/Restore Cursor + let input = b"\x1b[s\x1b[u"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_stress_many_sequences() { + let mut filter = EscapeSequenceFilter::new(); + + // Generate many color codes + let mut input = Vec::new(); + for i in 0..100 { + input.extend_from_slice(format!("\x1b[{}mX", 30 + (i % 8)).as_bytes()); + } + input.extend_from_slice(b"\x1b[0m"); + + let output = filter.filter(&input); + assert_eq!(output, input, "All color codes should pass through"); + } + + #[test] + fn test_flush_pending_returns_buffered_data() { + let mut filter = EscapeSequenceFilter::new(); + + // Start an incomplete DCS sequence + let _ = filter.filter(b"\x1bPsomedata"); + + let flushed = filter.flush_pending(); + assert_eq!(flushed, b"\x1bPsomedata"); + } + + #[test] + fn test_newlines_passthrough() { + let mut filter = EscapeSequenceFilter::new(); + + // Various newline combinations + let input = b"Line1\nLine2\rLine3\r\nLine4"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_tabs_passthrough() { + let mut filter = EscapeSequenceFilter::new(); + + // Tab characters + let input = b"Col1\tCol2\tCol3"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_backspace_passthrough() { + let mut filter = EscapeSequenceFilter::new(); + + // Backspace character + let input = b"Hello\x08\x08World"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } + + #[test] + fn test_bel_passthrough() { + let mut filter = EscapeSequenceFilter::new(); + + // BEL character alone (not as OSC terminator) + let input = b"Alert!\x07"; + let output = filter.filter(input); + assert_eq!(output, input.to_vec()); + } } diff --git a/src/pty/session/input.rs b/src/pty/session/input.rs index be77d54c..f3429926 100644 --- a/src/pty/session/input.rs +++ b/src/pty/session/input.rs @@ -13,6 +13,12 @@ // limitations under the License. //! Input event handling for PTY sessions +//! +//! NOTE: This module is currently unused since we switched to raw byte passthrough +//! (see issue #87), but is kept for reference, testing, and potential future use. + +// Allow dead code for the entire module +#![allow(dead_code)] use super::constants::*; use crossterm::event::{Event, KeyCode, KeyEvent, KeyEventKind, KeyModifiers, MouseEvent}; diff --git a/src/pty/session/local_escape.rs b/src/pty/session/local_escape.rs new file mode 100644 index 00000000..ac0dde3d --- /dev/null +++ b/src/pty/session/local_escape.rs @@ -0,0 +1,408 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Local escape sequence handling (OpenSSH-style). +//! +//! Handles sequences like `~.` for disconnect without sending to remote. +//! This matches OpenSSH's behavior for local command sequences. +//! +//! # Supported Escape Sequences +//! - `~.` - Terminate connection (must follow newline) +//! +//! # State Machine +//! The detector uses a state machine to track position in the escape sequence: +//! 1. After newline, wait for `~` +//! 2. After `~`, check for `.` +//! 3. On any other character, reset to waiting for newline + +use smallvec::SmallVec; + +/// Action to take after processing input. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum LocalAction { + /// Disconnect the session + Disconnect, + /// Pass data through to remote (optionally filtered). + /// Reserved for future escape sequences like `~?` (help) or `~~` (send literal tilde). + #[allow(dead_code)] + Passthrough(SmallVec<[u8; 64]>), +} + +/// State machine for detecting `~.` after newline. +/// +/// # Example +/// ```ignore +/// // Example is for documentation only - module is internal +/// let mut detector = LocalEscapeDetector::new(); +/// +/// // Normal input passes through +/// assert_eq!(detector.process(b"hello"), None); +/// +/// // Newline followed by ~. triggers disconnect +/// assert_eq!( +/// detector.process(b"\n~."), +/// Some(LocalAction::Disconnect) +/// ); +/// ``` +pub struct LocalEscapeDetector { + after_newline: bool, + saw_tilde: bool, +} + +impl LocalEscapeDetector { + /// Create a new escape detector. + /// + /// Starts in the "after newline" state to allow `~.` at the + /// beginning of a session. + pub fn new() -> Self { + Self { + after_newline: true, // Start as if after newline + saw_tilde: false, + } + } + + /// Process input and check for local escape sequences. + /// + /// Returns `None` if data should pass through unchanged, or + /// `Some(LocalAction)` if a local escape was detected. + /// + /// # Arguments + /// * `data` - Raw input bytes to process + /// + /// # Returns + /// - `None` - Data should be sent to remote as-is + /// - `Some(LocalAction::Disconnect)` - User requested disconnect + /// - `Some(LocalAction::Passthrough(filtered))` - Send filtered data + /// + /// # Example + /// ```ignore + /// // Example is for documentation only - module is internal + /// match detector.process(b"\n~.") { + /// Some(LocalAction::Disconnect) => { + /// // Close the connection + /// } + /// Some(LocalAction::Passthrough(data)) => { + /// // Send filtered data to remote + /// } + /// None => { + /// // Send data to remote unchanged + /// } + /// } + /// ``` + pub fn process(&mut self, data: &[u8]) -> Option { + for &byte in data { + match byte { + b'\r' | b'\n' => { + self.after_newline = true; + self.saw_tilde = false; + } + b'~' if self.after_newline => { + self.saw_tilde = true; + self.after_newline = false; + } + b'.' if self.saw_tilde => { + // Disconnect sequence detected + return Some(LocalAction::Disconnect); + } + _ => { + self.after_newline = false; + self.saw_tilde = false; + } + } + } + None // Pass through + } + + /// Reset the detector state. + /// + /// Useful when starting a new session or after handling an escape. + /// Currently unused but kept for API completeness and testing. + #[allow(dead_code)] + pub fn reset(&mut self) { + self.after_newline = true; + self.saw_tilde = false; + } +} + +impl Default for LocalEscapeDetector { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_normal_input_passes_through() { + let mut detector = LocalEscapeDetector::new(); + assert_eq!(detector.process(b"hello world"), None); + assert_eq!(detector.process(b"test\n"), None); + } + + #[test] + fn test_disconnect_after_newline() { + let mut detector = LocalEscapeDetector::new(); + detector.process(b"hello\n"); + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_disconnect_at_start() { + let mut detector = LocalEscapeDetector::new(); + // Starts in "after newline" state + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_tilde_without_dot() { + let mut detector = LocalEscapeDetector::new(); + detector.process(b"\n"); + assert_eq!(detector.process(b"~x"), None); + // State should reset after non-dot character + } + + #[test] + fn test_dot_without_tilde() { + let mut detector = LocalEscapeDetector::new(); + detector.process(b"\n"); + assert_eq!(detector.process(b"."), None); + } + + #[test] + fn test_tilde_not_after_newline() { + let mut detector = LocalEscapeDetector::new(); + assert_eq!(detector.process(b"x~."), None); + } + + #[test] + fn test_carriage_return_enables_escape() { + let mut detector = LocalEscapeDetector::new(); + detector.process(b"hello\r"); + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_reset() { + let mut detector = LocalEscapeDetector::new(); + detector.process(b"x"); + detector.reset(); + // After reset, should be in "after newline" state + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_default() { + let _detector = LocalEscapeDetector::default(); + } + + #[test] + fn test_multiple_sequences() { + let mut detector = LocalEscapeDetector::new(); + + // First sequence + detector.process(b"hello\n"); + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + + // Reset for next sequence + detector.reset(); + detector.process(b"world\r"); + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_partial_sequence_in_chunks() { + let mut detector = LocalEscapeDetector::new(); + + // Process in separate chunks + assert_eq!(detector.process(b"\n"), None); + assert_eq!(detector.process(b"~"), None); + assert_eq!(detector.process(b"."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_data_ending_with_tilde() { + let mut detector = LocalEscapeDetector::new(); + + // Data ends with tilde after newline - state should persist + assert_eq!(detector.process(b"\n~"), None); + // Subsequent dot should trigger disconnect + assert_eq!(detector.process(b"."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_data_ending_with_newline() { + let mut detector = LocalEscapeDetector::new(); + + // Data ends with newline - ready for escape + assert_eq!(detector.process(b"hello\n"), None); + // Subsequent ~. should trigger disconnect + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_consecutive_newlines() { + let mut detector = LocalEscapeDetector::new(); + + // Multiple consecutive newlines + assert_eq!(detector.process(b"\n\n\n"), None); + // Still in after_newline state + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_mixed_cr_and_lf() { + let mut detector = LocalEscapeDetector::new(); + + // CRLF sequence + assert_eq!(detector.process(b"\r\n"), None); + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_lfcr_sequence() { + let mut detector = LocalEscapeDetector::new(); + + // LFCR (unusual but possible) + assert_eq!(detector.process(b"\n\r"), None); + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_large_buffer_with_escape() { + let mut detector = LocalEscapeDetector::new(); + + // Large buffer with escape sequence in the middle + let mut data = vec![b'x'; 1000]; + data.push(b'\n'); + data.push(b'~'); + data.push(b'.'); + data.extend_from_slice(&[b'y'; 500]); + + // Should detect disconnect at ~. + assert_eq!(detector.process(&data), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_tilde_after_text() { + let mut detector = LocalEscapeDetector::new(); + + // Tilde in the middle of text (not after newline) + assert_eq!(detector.process(b"hello~.world"), None); + } + + #[test] + fn test_multiple_tildes() { + let mut detector = LocalEscapeDetector::new(); + + // Multiple tildes after newline + assert_eq!(detector.process(b"\n~~."), None); + // Second tilde resets the state + } + + #[test] + fn test_tilde_then_newline() { + let mut detector = LocalEscapeDetector::new(); + + // Tilde then newline resets + assert_eq!(detector.process(b"\n~\n"), None); + // Should be in after_newline state again + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_empty_input() { + let mut detector = LocalEscapeDetector::new(); + + // Empty input should not change state + assert_eq!(detector.process(b""), None); + // Still in initial after_newline state + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_escape_in_binary_data() { + let mut detector = LocalEscapeDetector::new(); + + // Binary data with escape sequence + let data = [0x00, 0xFF, b'\n', b'~', b'.', 0x7F]; + assert_eq!(detector.process(&data), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_only_newline_then_only_tilde() { + let mut detector = LocalEscapeDetector::new(); + + // Single byte inputs + assert_eq!(detector.process(b"\n"), None); + assert_eq!(detector.process(b"~"), None); + // State: saw_tilde = true, after_newline = false + assert_eq!(detector.process(b"."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_state_after_non_dot() { + let mut detector = LocalEscapeDetector::new(); + + // After ~x, state should reset + assert_eq!(detector.process(b"\n~x"), None); + // Need another newline before ~. + assert_eq!(detector.process(b"~."), None); + // Now with newline + assert_eq!(detector.process(b"\n~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_rapid_escape_attempts() { + let mut detector = LocalEscapeDetector::new(); + + // Rapid repeated attempts + assert_eq!( + detector.process(b"\n~x\n~y\n~z\n~."), + Some(LocalAction::Disconnect) + ); + } + + #[test] + fn test_unicode_does_not_interfere() { + let mut detector = LocalEscapeDetector::new(); + + // UTF-8 encoded characters should not interfere + assert_eq!(detector.process("ν•œκΈ€\n".as_bytes()), None); + assert_eq!(detector.process(b"~."), Some(LocalAction::Disconnect)); + } + + #[test] + fn test_local_action_eq() { + // Test LocalAction equality + assert_eq!(LocalAction::Disconnect, LocalAction::Disconnect); + } + + #[test] + fn test_local_action_debug() { + // Test LocalAction debug implementation + let action = LocalAction::Disconnect; + let debug_str = format!("{:?}", action); + assert!(debug_str.contains("Disconnect")); + } + + #[test] + fn test_local_action_clone() { + // Test LocalAction clone + let action = LocalAction::Disconnect; + let cloned = action.clone(); + assert_eq!(action, cloned); + } +} diff --git a/src/pty/session/mod.rs b/src/pty/session/mod.rs index 02b047df..b0504c00 100644 --- a/src/pty/session/mod.rs +++ b/src/pty/session/mod.rs @@ -17,6 +17,8 @@ mod constants; mod escape_filter; mod input; +mod local_escape; +mod raw_input; mod session_manager; mod terminal_modes; diff --git a/src/pty/session/raw_input.rs b/src/pty/session/raw_input.rs new file mode 100644 index 00000000..925167ec --- /dev/null +++ b/src/pty/session/raw_input.rs @@ -0,0 +1,253 @@ +// Copyright 2025 Lablup Inc. and Jeongkyu Shin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Raw byte input reader for PTY sessions. +//! +//! Reads stdin as raw bytes without escape sequence parsing, +//! providing transparent passthrough like OpenSSH. +//! +//! # Prerequisites +//! This module requires `crossterm::terminal::enable_raw_mode()` to be called +//! before reading. The raw mode ensures: +//! - No line buffering (bytes available immediately) +//! - No echo (typed characters not displayed by terminal) +//! - No signal generation (Ctrl+C doesn't generate SIGINT) +//! +//! # Why Raw Bytes? +//! Using crossterm's `event::read()` parses escape sequences, which consumes +//! the ESC byte (0x1b) and corrupts terminal responses. Reading raw bytes with +//! `stdin.read()` provides transparent passthrough of all bytes, including: +//! - Terminal query responses (DA1, DA2, DA3, XTGETTCAP, etc.) +//! - Arrow keys (`\x1b[A`, `\x1b[B`, `\x1b[C`, `\x1b[D`) +//! - Function keys (`\x1bOP`, `\x1bOQ`, etc.) +//! - Mouse events +//! +//! This approach matches OpenSSH's behavior. + +use std::io::{self, Read}; +use std::os::unix::io::AsRawFd; +use std::time::Duration; + +/// Raw input reader that provides transparent byte passthrough. +/// +/// # Usage +/// ```ignore +/// // Example is for documentation only - module is internal +/// use std::time::Duration; +/// +/// // Ensure raw mode is enabled first +/// crossterm::terminal::enable_raw_mode().unwrap(); +/// +/// let mut reader = RawInputReader::new(); +/// let mut buffer = [0u8; 1024]; +/// +/// if reader.poll(Duration::from_millis(100)).unwrap() { +/// let n = reader.read(&mut buffer).unwrap(); +/// // Process raw bytes... +/// } +/// +/// crossterm::terminal::disable_raw_mode().unwrap(); +/// ``` +pub struct RawInputReader { + stdin: io::Stdin, +} + +impl RawInputReader { + /// Create a new raw input reader. + /// + /// # Prerequisites + /// The terminal must be in raw mode (via `enable_raw_mode()`) before + /// calling `read()` to ensure immediate byte availability. + pub fn new() -> Self { + Self { stdin: io::stdin() } + } + + /// Poll for available input with timeout. + /// + /// Returns `Ok(true)` if data is available to read, `Ok(false)` if timeout + /// occurred, or an error if the poll failed. + /// + /// # Arguments + /// * `timeout` - Maximum time to wait for input. Values greater than 65535ms + /// will be clamped to 65535ms due to poll() limitations. + /// + /// # Example + /// ```ignore + /// // Example is for documentation only - module is internal + /// use std::time::Duration; + /// let reader = RawInputReader::new(); + /// if reader.poll(Duration::from_millis(100))? { + /// // Data is available + /// } + /// ``` + pub fn poll(&self, timeout: Duration) -> io::Result { + use nix::poll::{poll, PollFd, PollFlags, PollTimeout}; + use std::os::unix::io::BorrowedFd; + + let fd = self.stdin.as_raw_fd(); + // SAFETY: + // 1. We hold a reference to `self.stdin` for the entire function scope + // 2. `stdin` is owned by this struct and cannot be closed externally + // 3. The BorrowedFd is used only within this function and not stored + let borrowed_fd = unsafe { BorrowedFd::borrow_raw(fd) }; + let mut poll_fds = [PollFd::new(borrowed_fd, PollFlags::POLLIN)]; + + // Convert Duration to PollTimeout + // PollTimeout accepts u16 in milliseconds (or Option for -1) + let timeout_ms = timeout.as_millis().min(u16::MAX as u128) as u16; + let poll_timeout = PollTimeout::from(timeout_ms); + + match poll(&mut poll_fds, poll_timeout) { + Ok(n) => Ok(n > 0), + Err(nix::errno::Errno::EINTR) => Ok(false), // Interrupted, treat as timeout + Err(e) => Err(io::Error::from_raw_os_error(e as i32)), + } + } + + /// Read available bytes from stdin. + /// + /// Returns the number of bytes read. A return value of 0 indicates EOF. + /// + /// # Raw Mode Behavior + /// When terminal is in raw mode (via `enable_raw_mode()`), this returns + /// raw bytes including escape sequences like: + /// - Arrow keys: `\x1b[A`, `\x1b[B`, `\x1b[C`, `\x1b[D` + /// - Function keys: `\x1bOP`, `\x1bOQ`, etc. + /// - Terminal responses: `\x1b[>64;2500;0c`, etc. + /// - Mouse events: `\x1b[<...M` + /// + /// All bytes are passed through as-is without interpretation. + /// + /// # Example + /// ```ignore + /// // Example is for documentation only - module is internal + /// let mut reader = RawInputReader::new(); + /// let mut buffer = [0u8; 1024]; + /// + /// match reader.read(&mut buffer)? { + /// 0 => println!("EOF"), + /// n => println!("Read {} bytes", n), + /// } + /// ``` + pub fn read(&mut self, buffer: &mut [u8]) -> io::Result { + self.stdin.read(buffer) + } +} + +impl Default for RawInputReader { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_raw_input_reader_creation() { + let _reader = RawInputReader::new(); + // If we can create it, the test passes + } + + #[test] + fn test_default() { + let _reader = RawInputReader::default(); + } + + #[test] + fn test_poll_timeout() { + let reader = RawInputReader::new(); + // Short timeout should return false when no input + let result = reader.poll(Duration::from_millis(10)); + assert!(result.is_ok()); + // We can't guarantee false since input might be available + } + + #[test] + fn test_poll_timeout_clamping_at_u16_max() { + let reader = RawInputReader::new(); + // Verify poll accepts values above u16::MAX (65535ms) + // The implementation clamps to u16::MAX internally + let result = reader.poll(Duration::from_millis(70000)); + assert!(result.is_ok()); + } + + #[test] + fn test_poll_timeout_very_large_duration() { + let reader = RawInputReader::new(); + // Test with very large duration (1 hour) + // Should be clamped to 65535ms + let result = reader.poll(Duration::from_secs(3600)); + assert!(result.is_ok()); + } + + #[test] + fn test_poll_zero_timeout() { + let reader = RawInputReader::new(); + // Zero timeout should return immediately + let result = reader.poll(Duration::ZERO); + assert!(result.is_ok()); + } + + #[test] + fn test_poll_one_millisecond_timeout() { + let reader = RawInputReader::new(); + // Very short timeout + let result = reader.poll(Duration::from_millis(1)); + assert!(result.is_ok()); + } + + #[test] + fn test_poll_exactly_u16_max() { + let reader = RawInputReader::new(); + // Test exactly at the boundary (65535ms) + let result = reader.poll(Duration::from_millis(u16::MAX as u64)); + assert!(result.is_ok()); + } + + #[test] + fn test_poll_just_over_u16_max() { + let reader = RawInputReader::new(); + // Test just over the boundary (65536ms) + let result = reader.poll(Duration::from_millis(u16::MAX as u64 + 1)); + assert!(result.is_ok()); + } + + #[test] + fn test_multiple_sequential_polls() { + let reader = RawInputReader::new(); + // Multiple polls should work consistently + for _ in 0..5 { + let result = reader.poll(Duration::from_millis(1)); + assert!(result.is_ok()); + } + } + + #[test] + fn test_poll_with_nanoseconds() { + let reader = RawInputReader::new(); + // Duration with nanoseconds (will be truncated to milliseconds) + let result = reader.poll(Duration::from_nanos(1_500_000)); // 1.5ms -> 1ms + assert!(result.is_ok()); + } + + #[test] + fn test_poll_sub_millisecond() { + let reader = RawInputReader::new(); + // Sub-millisecond duration (should become 0ms) + let result = reader.poll(Duration::from_micros(500)); // 0.5ms -> 0ms + assert!(result.is_ok()); + } +} diff --git a/src/pty/session/session_manager.rs b/src/pty/session/session_manager.rs index 93413fe2..b79bdf1a 100644 --- a/src/pty/session/session_manager.rs +++ b/src/pty/session/session_manager.rs @@ -16,7 +16,8 @@ use super::constants::*; use super::escape_filter::EscapeSequenceFilter; -use super::input::handle_input_event; +use super::local_escape::{LocalAction, LocalEscapeDetector}; +use super::raw_input::RawInputReader; use super::terminal_modes::configure_terminal_modes; use crate::pty::{ terminal::{TerminalOps, TerminalStateGuard}, @@ -223,34 +224,70 @@ impl PtySession { let cancel_for_input = self.cancel_rx.clone(); // Spawn input reader in blocking thread pool to avoid blocking async runtime + // NOTE: TerminalStateGuard has already called enable_raw_mode() at this point, + // so stdin.read() will return raw bytes without line buffering let input_task = tokio::task::spawn_blocking(move || { - // This runs in a dedicated thread pool for blocking operations + let mut reader = RawInputReader::new(); + let mut buffer = [0u8; 1024]; + let mut escape_detector = LocalEscapeDetector::new(); + loop { if *cancel_for_input.borrow() { break; } - // Poll with timeout since we're in blocking thread let poll_timeout = Duration::from_millis(INPUT_POLL_TIMEOUT_MS); - // Check for input events with timeout (blocking is OK here) - if crossterm::event::poll(poll_timeout).unwrap_or(false) { - match crossterm::event::read() { - Ok(event) => { - if let Some(data) = handle_input_event(event) { - // Use try_send to avoid blocking on bounded channel - if input_tx.try_send(PtyMessage::LocalInput(data)).is_err() { - // Channel is either full or closed - // For input, we should break on error as it means session is ending - break; + match reader.poll(poll_timeout) { + Ok(true) => { + match reader.read(&mut buffer) { + Ok(0) => { + // EOF - user closed stdin + tracing::debug!("EOF received on stdin"); + break; + } + Ok(n) => { + // Check for local escape sequences (e.g., ~. for disconnect) + if let Some(action) = escape_detector.process(&buffer[..n]) { + match action { + LocalAction::Disconnect => { + tracing::debug!("Disconnect escape sequence detected"); + let _ = input_tx.try_send(PtyMessage::Terminate); + break; + } + LocalAction::Passthrough(data) => { + // Send filtered data + if input_tx + .try_send(PtyMessage::LocalInput(data)) + .is_err() + { + break; + } + } + } + } else { + // Pass raw bytes through as-is + // This includes arrow keys, function keys, terminal responses, etc. + let data = smallvec::SmallVec::from_slice(&buffer[..n]); + if input_tx.try_send(PtyMessage::LocalInput(data)).is_err() { + break; + } } } + Err(e) => { + let _ = input_tx + .try_send(PtyMessage::Error(format!("Input error: {e}"))); + break; + } } - Err(e) => { - let _ = - input_tx.try_send(PtyMessage::Error(format!("Input error: {e}"))); - break; - } + } + Ok(false) => { + // Timeout - continue polling + continue; + } + Err(e) => { + let _ = input_tx.try_send(PtyMessage::Error(format!("Poll error: {e}"))); + break; } } } diff --git a/src/pty/session/terminal_modes.rs b/src/pty/session/terminal_modes.rs index af79964e..ce46d331 100644 --- a/src/pty/session/terminal_modes.rs +++ b/src/pty/session/terminal_modes.rs @@ -89,3 +89,292 @@ pub fn configure_terminal_modes() -> Vec<(Pty, u32)> { (Pty::TTY_OP_OSPEED, 38400), // Output baud rate ] } + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper function to find a mode's value in the modes list + fn find_mode(modes: &[(Pty, u32)], target: Pty) -> Option { + modes.iter().find(|(k, _)| *k == target).map(|(_, v)| *v) + } + + #[test] + fn test_configure_terminal_modes_returns_non_empty() { + let modes = configure_terminal_modes(); + assert!(!modes.is_empty(), "Terminal modes should not be empty"); + } + + #[test] + fn test_configure_terminal_modes_count() { + let modes = configure_terminal_modes(); + // We expect a comprehensive set of terminal modes + // Currently 38 modes: 14 control chars + 12 input modes + 10 local modes + 5 output modes + 3 control modes + 2 baud rates - some overlap + assert!( + modes.len() >= 30, + "Expected at least 30 terminal modes, got {}", + modes.len() + ); + } + + #[test] + fn test_control_characters_configured() { + let modes = configure_terminal_modes(); + + // Verify critical control characters + assert_eq!( + find_mode(&modes, Pty::VINTR), + Some(0x03), + "VINTR should be Ctrl+C (0x03)" + ); + assert_eq!( + find_mode(&modes, Pty::VEOF), + Some(0x04), + "VEOF should be Ctrl+D (0x04)" + ); + assert_eq!( + find_mode(&modes, Pty::VSUSP), + Some(0x1A), + "VSUSP should be Ctrl+Z (0x1A)" + ); + assert_eq!( + find_mode(&modes, Pty::VERASE), + Some(0x7F), + "VERASE should be DEL (0x7F)" + ); + assert_eq!( + find_mode(&modes, Pty::VKILL), + Some(0x15), + "VKILL should be Ctrl+U (0x15)" + ); + } + + #[test] + fn test_signal_generation_enabled() { + let modes = configure_terminal_modes(); + + // ISIG enables signal generation (critical for Ctrl+C, Ctrl+Z) + assert_eq!( + find_mode(&modes, Pty::ISIG), + Some(1), + "ISIG should be enabled for signal generation" + ); + } + + #[test] + fn test_canonical_mode_enabled() { + let modes = configure_terminal_modes(); + + // ICANON enables line editing (backspace, etc.) + assert_eq!( + find_mode(&modes, Pty::ICANON), + Some(1), + "ICANON should be enabled for line editing" + ); + } + + #[test] + fn test_echo_enabled() { + let modes = configure_terminal_modes(); + + // ECHO enables character echo (programs can disable for passwords) + assert_eq!( + find_mode(&modes, Pty::ECHO), + Some(1), + "ECHO should be enabled by default" + ); + } + + #[test] + fn test_cr_to_nl_mapping() { + let modes = configure_terminal_modes(); + + // ICRNL maps CR to NL (Enter key works correctly) + assert_eq!( + find_mode(&modes, Pty::ICRNL), + Some(1), + "ICRNL should be enabled for Enter key" + ); + } + + #[test] + fn test_output_processing() { + let modes = configure_terminal_modes(); + + // OPOST enables output processing + assert_eq!( + find_mode(&modes, Pty::OPOST), + Some(1), + "OPOST should be enabled for output processing" + ); + // ONLCR maps NL to CR-NL (proper line endings) + assert_eq!( + find_mode(&modes, Pty::ONLCR), + Some(1), + "ONLCR should be enabled for proper line endings" + ); + } + + #[test] + fn test_8bit_character_size() { + let modes = configure_terminal_modes(); + + // CS8 enables 8-bit characters + assert_eq!( + find_mode(&modes, Pty::CS8), + Some(1), + "CS8 should be enabled for 8-bit characters" + ); + } + + #[test] + fn test_flow_control_disabled() { + let modes = configure_terminal_modes(); + + // Flow control disabled so Ctrl+S/Ctrl+Q work normally + assert_eq!( + find_mode(&modes, Pty::IXON), + Some(0), + "IXON should be disabled (no flow control)" + ); + assert_eq!( + find_mode(&modes, Pty::IXOFF), + Some(0), + "IXOFF should be disabled (no flow control)" + ); + } + + #[test] + fn test_baud_rates() { + let modes = configure_terminal_modes(); + + // Baud rates should be set (nominal values) + assert_eq!( + find_mode(&modes, Pty::TTY_OP_ISPEED), + Some(38400), + "Input baud rate should be 38400" + ); + assert_eq!( + find_mode(&modes, Pty::TTY_OP_OSPEED), + Some(38400), + "Output baud rate should be 38400" + ); + } + + #[test] + fn test_parity_disabled() { + let modes = configure_terminal_modes(); + + assert_eq!( + find_mode(&modes, Pty::PARENB), + Some(0), + "Parity should be disabled" + ); + } + + #[test] + fn test_disabled_control_chars_set_to_0xff() { + let modes = configure_terminal_modes(); + + // Disabled control characters should be 0xFF + assert_eq!( + find_mode(&modes, Pty::VEOL), + Some(0xFF), + "VEOL should be disabled (0xFF)" + ); + assert_eq!( + find_mode(&modes, Pty::VEOL2), + Some(0xFF), + "VEOL2 should be disabled (0xFF)" + ); + } + + #[test] + fn test_extended_input_processing() { + let modes = configure_terminal_modes(); + + // IEXTEN enables extended processing (Ctrl+V literal, etc.) + assert_eq!( + find_mode(&modes, Pty::IEXTEN), + Some(1), + "IEXTEN should be enabled for extended input" + ); + } + + #[test] + fn test_no_duplicate_modes() { + let modes = configure_terminal_modes(); + + for (i, (mode_i, _)) in modes.iter().enumerate() { + for (j, (mode_j, _)) in modes.iter().enumerate() { + if i != j { + assert!( + mode_i != mode_j, + "Duplicate terminal mode found: {:?}", + mode_i + ); + } + } + } + } + + #[test] + fn test_all_control_chars_present() { + let modes = configure_terminal_modes(); + + // Check all expected control characters are present + let control_chars = [ + Pty::VINTR, + Pty::VQUIT, + Pty::VERASE, + Pty::VKILL, + Pty::VEOF, + Pty::VEOL, + Pty::VEOL2, + Pty::VSTART, + Pty::VSTOP, + Pty::VSUSP, + Pty::VREPRINT, + Pty::VWERASE, + Pty::VLNEXT, + Pty::VDISCARD, + ]; + + for ctrl in control_chars { + assert!( + find_mode(&modes, ctrl).is_some(), + "Control character {:?} should be present", + ctrl + ); + } + } + + #[test] + fn test_xon_xoff_chars() { + let modes = configure_terminal_modes(); + + // VSTART (Ctrl+Q) and VSTOP (Ctrl+S) should be configured + assert_eq!( + find_mode(&modes, Pty::VSTART), + Some(0x11), + "VSTART should be Ctrl+Q (0x11)" + ); + assert_eq!( + find_mode(&modes, Pty::VSTOP), + Some(0x13), + "VSTOP should be Ctrl+S (0x13)" + ); + } + + #[test] + fn test_visual_erase_enabled() { + let modes = configure_terminal_modes(); + + // ECHOE enables visual erase (backspace removes characters visually) + assert_eq!( + find_mode(&modes, Pty::ECHOE), + Some(1), + "ECHOE should be enabled for visual erase" + ); + } +}