diff --git a/src/shell/commands/head.rs b/src/shell/commands/head.rs index acfa7c0..cd84cde 100644 --- a/src/shell/commands/head.rs +++ b/src/shell/commands/head.rs @@ -81,47 +81,147 @@ fn copy_lines Result>( Ok(ExecuteResult::from_exit_code(0)) } +fn copy_all_but_last_lines Result>( + writer: &mut ShellPipeWriter, + skip_last: u64, + kill_signal: &KillSignal, + mut read: F, +) -> Result { + // read all content first + let mut content = Vec::new(); + let mut buffer = vec![0; 512]; + loop { + if let Some(exit_code) = kill_signal.aborted_code() { + return Ok(ExecuteResult::from_exit_code(exit_code)); + } + let read_bytes = read(&mut buffer)?; + if read_bytes == 0 { + break; + } + content.extend_from_slice(&buffer[..read_bytes]); + } + + // count total lines + let total_lines = content.iter().filter(|&&b| b == b'\n').count() as u64; + + // output all but the last N lines + if total_lines <= skip_last { + return Ok(ExecuteResult::from_exit_code(0)); + } + let lines_to_print = total_lines - skip_last; + + let mut line_count = 0u64; + let mut start = 0; + for (i, &b) in content.iter().enumerate() { + if let Some(exit_code) = kill_signal.aborted_code() { + return Ok(ExecuteResult::from_exit_code(exit_code)); + } + if b == b'\n' { + line_count += 1; + if line_count <= lines_to_print { + writer.write_all(&content[start..=i])?; + } + start = i + 1; + if line_count >= lines_to_print { + break; + } + } + } + + Ok(ExecuteResult::from_exit_code(0)) +} + fn execute_head(mut context: ShellCommandContext) -> Result { let flags = parse_args(&context.args)?; - if flags.path == "-" { - copy_lines( - &mut context.stdout, - flags.lines, - context.state.kill_signal(), - |buf| context.stdin.read(buf), - 512, - ) - } else { - let path = flags.path; - match File::open(context.state.cwd().join(path)) { - Ok(mut file) => copy_lines( - &mut context.stdout, - flags.lines, - context.state.kill_signal(), - |buf| file.read(buf).map_err(Into::into), - 512, - ), - Err(err) => { - context.stderr.write_line(&format!( - "head: {}: {}", - path.to_string_lossy(), - err - ))?; - Ok(ExecuteResult::from_exit_code(1)) + match flags.lines { + LineCount::First(max_lines) => { + if flags.path == "-" { + copy_lines( + &mut context.stdout, + max_lines, + context.state.kill_signal(), + |buf| context.stdin.read(buf), + 512, + ) + } else { + let path = flags.path; + match File::open(context.state.cwd().join(path)) { + Ok(mut file) => copy_lines( + &mut context.stdout, + max_lines, + context.state.kill_signal(), + |buf| file.read(buf).map_err(Into::into), + 512, + ), + Err(err) => { + context.stderr.write_line(&format!( + "head: {}: {}", + path.to_string_lossy(), + err + ))?; + Ok(ExecuteResult::from_exit_code(1)) + } + } + } + } + LineCount::AllButLast(skip_last) => { + if flags.path == "-" { + copy_all_but_last_lines( + &mut context.stdout, + skip_last, + context.state.kill_signal(), + |buf| context.stdin.read(buf), + ) + } else { + let path = flags.path; + match File::open(context.state.cwd().join(path)) { + Ok(mut file) => copy_all_but_last_lines( + &mut context.stdout, + skip_last, + context.state.kill_signal(), + |buf| file.read(buf).map_err(Into::into), + ), + Err(err) => { + context.stderr.write_line(&format!( + "head: {}: {}", + path.to_string_lossy(), + err + ))?; + Ok(ExecuteResult::from_exit_code(1)) + } + } } } } } +#[derive(Debug, PartialEq, Clone, Copy)] +enum LineCount { + /// print first N lines + First(u64), + /// print all but last N lines + AllButLast(u64), +} + #[derive(Debug, PartialEq)] struct HeadFlags<'a> { path: &'a OsStr, - lines: u64, + lines: LineCount, +} + +fn parse_line_count(s: &str) -> Result { + if let Some(rest) = s.strip_prefix('-') { + let num = rest.parse::()?; + Ok(LineCount::AllButLast(num)) + } else { + let num = s.parse::()?; + Ok(LineCount::First(num)) + } } fn parse_args<'a>(args: &'a [OsString]) -> Result> { let mut path: Option<&'a OsStr> = None; - let mut lines: Option = None; + let mut lines: Option = None; let mut iterator = parse_arg_kinds(args).into_iter(); while let Some(arg) = iterator.next() { match arg { @@ -137,9 +237,11 @@ fn parse_args<'a>(args: &'a [OsString]) -> Result> { } ArgKind::ShortFlag('n') => match iterator.next() { Some(ArgKind::Arg(arg)) => { - let num = arg.to_str().and_then(|a| a.parse::().ok()); - if let Some(num) = num { - lines = Some(num); + if let Some(s) = arg.to_str() { + match parse_line_count(s) { + Ok(count) => lines = Some(count), + Err(_) => bail!("expected a numeric value following -n"), + } } else { bail!("expected a numeric value following -n") } @@ -150,7 +252,7 @@ fn parse_args<'a>(args: &'a [OsString]) -> Result> { if flag == "lines" || flag == "lines=" { bail!("expected a value for --lines"); } else if let Some(arg) = flag.strip_prefix("lines=") { - lines = Some(arg.parse::()?); + lines = Some(parse_line_count(arg)?); } else { arg.bail_unsupported()? } @@ -161,7 +263,7 @@ fn parse_args<'a>(args: &'a [OsString]) -> Result> { Ok(HeadFlags { path: path.unwrap_or(OsStr::new("-")), - lines: lines.unwrap_or(10), + lines: lines.unwrap_or(LineCount::First(10)), }) } @@ -231,56 +333,78 @@ mod test { parse_args(&[]).unwrap(), HeadFlags { path: OsStr::new("-"), - lines: 10 + lines: LineCount::First(10) } ); assert_eq!( parse_args(&["-n".into(), "5".into()]).unwrap(), HeadFlags { path: OsStr::new("-"), - lines: 5 + lines: LineCount::First(5) } ); assert_eq!( parse_args(&["--lines=5".into()]).unwrap(), HeadFlags { path: OsStr::new("-"), - lines: 5 + lines: LineCount::First(5) } ); assert_eq!( parse_args(&["path".into()]).unwrap(), HeadFlags { path: OsStr::new("path"), - lines: 10 + lines: LineCount::First(10) } ); assert_eq!( parse_args(&["-n".into(), "5".into(), "path".into()]).unwrap(), HeadFlags { path: OsStr::new("path"), - lines: 5 + lines: LineCount::First(5) } ); assert_eq!( parse_args(&["--lines=5".into(), "path".into()]).unwrap(), HeadFlags { path: OsStr::new("path"), - lines: 5 + lines: LineCount::First(5) } ); assert_eq!( parse_args(&["path".into(), "-n".into(), "5".into()]).unwrap(), HeadFlags { path: OsStr::new("path"), - lines: 5 + lines: LineCount::First(5) } ); assert_eq!( parse_args(&["path".into(), "--lines=5".into()]).unwrap(), HeadFlags { path: OsStr::new("path"), - lines: 5 + lines: LineCount::First(5) + } + ); + // negative line counts (all but last N) + assert_eq!( + parse_args(&["-n".into(), "-1".into()]).unwrap(), + HeadFlags { + path: OsStr::new("-"), + lines: LineCount::AllButLast(1) + } + ); + assert_eq!( + parse_args(&["-n".into(), "-5".into(), "path".into()]).unwrap(), + HeadFlags { + path: OsStr::new("path"), + lines: LineCount::AllButLast(5) + } + ); + assert_eq!( + parse_args(&["--lines=-3".into()]).unwrap(), + HeadFlags { + path: OsStr::new("-"), + lines: LineCount::AllButLast(3) } ); assert_eq!( @@ -304,4 +428,83 @@ mod test { "unsupported flag: -t" ); } + + #[tokio::test] + async fn copies_all_but_last_lines() { + let (reader, mut writer) = pipe(); + let reader_handle = reader.pipe_to_string_handle(); + let data = b"line1\nline2\nline3\nline4\nline5\n"; + let mut offset = 0; + let result = copy_all_but_last_lines( + &mut writer, + 1, + &KillSignal::default(), + |buffer| { + if offset >= data.len() { + return Ok(0); + } + let read_length = min(buffer.len(), data.len() - offset); + buffer[..read_length] + .copy_from_slice(&data[offset..(offset + read_length)]); + offset += read_length; + Ok(read_length) + }, + ); + drop(writer); + assert_eq!(reader_handle.await.unwrap(), "line1\nline2\nline3\nline4\n"); + assert_eq!(result.unwrap().into_exit_code_and_handles().0, 0); + } + + #[tokio::test] + async fn copies_all_but_last_two_lines() { + let (reader, mut writer) = pipe(); + let reader_handle = reader.pipe_to_string_handle(); + let data = b"line1\nline2\nline3\nline4\nline5\n"; + let mut offset = 0; + let result = copy_all_but_last_lines( + &mut writer, + 2, + &KillSignal::default(), + |buffer| { + if offset >= data.len() { + return Ok(0); + } + let read_length = min(buffer.len(), data.len() - offset); + buffer[..read_length] + .copy_from_slice(&data[offset..(offset + read_length)]); + offset += read_length; + Ok(read_length) + }, + ); + drop(writer); + assert_eq!(reader_handle.await.unwrap(), "line1\nline2\nline3\n"); + assert_eq!(result.unwrap().into_exit_code_and_handles().0, 0); + } + + #[tokio::test] + async fn copies_all_but_last_lines_when_skip_exceeds_total() { + let (reader, mut writer) = pipe(); + let reader_handle = reader.pipe_to_string_handle(); + let data = b"line1\nline2\n"; + let mut offset = 0; + let result = copy_all_but_last_lines( + &mut writer, + 5, + &KillSignal::default(), + |buffer| { + if offset >= data.len() { + return Ok(0); + } + let read_length = min(buffer.len(), data.len() - offset); + buffer[..read_length] + .copy_from_slice(&data[offset..(offset + read_length)]); + offset += read_length; + Ok(read_length) + }, + ); + drop(writer); + // when skip_last >= total_lines, output should be empty + assert_eq!(reader_handle.await.unwrap(), ""); + assert_eq!(result.unwrap().into_exit_code_and_handles().0, 0); + } } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 362c86c..ba94079 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -953,6 +953,46 @@ async fn head() { .assert_stdout("foo\nbar\nbaz\n") .run() .await; + + // negative line count: -n -1 (all but last 1 line) + TestBuilder::new() + .command("head -n -1") + .stdin("line1\nline2\nline3\nline4\nline5\n") + .assert_stdout("line1\nline2\nline3\nline4\n") + .run() + .await; + + // negative line count: -n -2 (all but last 2 lines) + TestBuilder::new() + .command("head -n -2") + .stdin("line1\nline2\nline3\nline4\nline5\n") + .assert_stdout("line1\nline2\nline3\n") + .run() + .await; + + // negative line count with --lines + TestBuilder::new() + .command("head --lines=-3") + .stdin("line1\nline2\nline3\nline4\nline5\n") + .assert_stdout("line1\nline2\n") + .run() + .await; + + // negative line count with file + TestBuilder::new() + .command("head -n -1 file") + .file("file", "line1\nline2\nline3\nline4\nline5\n") + .assert_stdout("line1\nline2\nline3\nline4\n") + .run() + .await; + + // negative line count where skip >= total lines (empty output) + TestBuilder::new() + .command("head -n -10") + .stdin("line1\nline2\nline3\n") + .assert_stdout("") + .run() + .await; } // Basic integration tests as there are unit tests in the commands