Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions benchmark/check_until.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Benchmark definition comparing StringScanner#check_until when the pattern is
# given as a Regexp literal, a Regexp variable, a String literal, or a String
# variable. The prelude runs once and sets up the scanner and the two
# pattern variables that the benchmark snippets refer to.
prelude: |-
  $LOAD_PATH.unshift(File.expand_path("lib"))
  require "strscan"
  scanner = StringScanner.new("test string")
  str = "string"
  reg = /string/
benchmark:
  regexp: |
    scanner.check_until(/string/)
  regexp_var: |
    scanner.check_until(reg)
  string: |
    scanner.check_until("string")
  string_var: |
    scanner.check_until(str)
30 changes: 13 additions & 17 deletions ext/jruby/org/jruby/ext/strscan/RubyStringScanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -262,17 +262,6 @@ private IRubyObject extractBegLen(Ruby runtime, int beg, int len) {
// MRI: strscan_do_scan
private IRubyObject scan(ThreadContext context, IRubyObject regex, boolean succptr, boolean getstr, boolean headonly) {
final Ruby runtime = context.runtime;

if (headonly) {
if (!(regex instanceof RubyRegexp)) {
regex = regex.convertToString();
}
} else {
if (!(regex instanceof RubyRegexp)) {
throw runtime.newTypeError("wrong argument type " + regex.getMetaClass() + " (expected Regexp)");
}
}

check(context);

ByteList strBL = str.getByteList();
Expand Down Expand Up @@ -310,9 +299,9 @@ private IRubyObject scan(ThreadContext context, IRubyObject regex, boolean succp
}
if (ret < 0) return context.nil;
} else {
RubyString pattern = (RubyString) regex;
RubyString pattern = regex.convertToString();

str.checkEncoding(pattern);
Encoding patternEnc = str.checkEncoding(pattern);

if (restLen() < pattern.size()) {
return context.nil;
Expand All @@ -321,11 +310,18 @@ private IRubyObject scan(ThreadContext context, IRubyObject regex, boolean succp
ByteList patternBL = pattern.getByteList();
int patternSize = patternBL.realSize();

if (ByteList.memcmp(strBL.unsafeBytes(), strBeg + curr, patternBL.unsafeBytes(), patternBL.begin(), patternSize) != 0) {
return context.nil;
if (headonly) {
if (ByteList.memcmp(strBL.unsafeBytes(), strBeg + curr, patternBL.unsafeBytes(), patternBL.begin(), patternSize) != 0) {
return context.nil;
}
setRegisters(patternSize);
} else {
int pos = StringSupport.index(strBL, patternBL, strBeg + curr, patternEnc);
if (pos == -1) {
return context.nil;
}
setRegisters(patternSize + pos - curr);
}

setRegisters(patternSize);
}

setMatched();
Expand Down
25 changes: 14 additions & 11 deletions ext/strscan/strscan.c
Original file line number Diff line number Diff line change
Expand Up @@ -686,14 +686,6 @@ strscan_do_scan(VALUE self, VALUE pattern, int succptr, int getstr, int headonly
{
struct strscanner *p;

if (headonly) {
if (!RB_TYPE_P(pattern, T_REGEXP)) {
StringValue(pattern);
}
}
else {
Check_Type(pattern, T_REGEXP);
}
GET_SCANNER(self, p);

CLEAR_MATCH_STATUS(p);
Expand All @@ -714,14 +706,25 @@ strscan_do_scan(VALUE self, VALUE pattern, int succptr, int getstr, int headonly
}
}
else {
StringValue(pattern);
rb_enc_check(p->str, pattern);
if (S_RESTLEN(p) < RSTRING_LEN(pattern)) {
return Qnil;
}
if (memcmp(CURPTR(p), RSTRING_PTR(pattern), RSTRING_LEN(pattern)) != 0) {
return Qnil;

if (headonly) {
if (memcmp(CURPTR(p), RSTRING_PTR(pattern), RSTRING_LEN(pattern)) != 0) {
return Qnil;
}
set_registers(p, RSTRING_LEN(pattern));
} else {
long pos = rb_memsearch(RSTRING_PTR(pattern), RSTRING_LEN(pattern),
CURPTR(p), S_RESTLEN(p), rb_enc_get(pattern));
if (pos == -1) {
return Qnil;
}
set_registers(p, RSTRING_LEN(pattern) + pos);
}
set_registers(p, RSTRING_LEN(pattern));
}

MATCHED(p);
Expand Down
80 changes: 75 additions & 5 deletions test/strscan/test_stringscanner.rb
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,15 @@ def test_concat
end

def test_scan
s = create_string_scanner('stra strb strc', true)
s = create_string_scanner("stra strb\0strc", true)
tmp = s.scan(/\w+/)
assert_equal 'stra', tmp

tmp = s.scan(/\s+/)
assert_equal ' ', tmp

assert_equal 'strb', s.scan(/\w+/)
assert_equal ' ', s.scan(/\s+/)
assert_equal "\u0000", s.scan(/\0/)

tmp = s.scan(/\w+/)
assert_equal 'strc', tmp
Expand Down Expand Up @@ -312,11 +312,14 @@ def test_scan
end

def test_scan_string
s = create_string_scanner('stra strb strc')
s = create_string_scanner("stra strb\0strc")
assert_equal 'str', s.scan('str')
assert_equal 'str', s[0]
assert_equal 3, s.pos
assert_equal 'a ', s.scan('a ')
assert_equal 'strb', s.scan('strb')
assert_equal "\u0000", s.scan("\0")
assert_equal 'strc', s.scan('strc')

str = 'stra strb strc'.dup
s = create_string_scanner(str, false)
Expand Down Expand Up @@ -668,13 +671,47 @@ def test_exist_p
assert_equal(nil, s.exist?(/e/))
end

def test_exist_p_string
def test_exist_p_invalid_argument
s = create_string_scanner("test string")
assert_raise(TypeError) do
s.exist?(" ")
s.exist?(1)
end
end

# Exercises StringScanner#exist? with String patterns: the result is the
# offset just past the match measured from the current position, the scan
# pointer is left where it was, and a pattern that is absent yields nil.
def test_exist_p_string
  if RUBY_ENGINE == "truffleruby"
    omit("not implemented on TruffleRuby")
  end

  scanner = create_string_scanner("test string")

  # From the start, the first "s" ends at offset 3; exist? must not move pos.
  assert_equal(3, scanner.exist?("s"))
  assert_equal(0, scanner.pos)

  # After consuming "test" the remainder is " string", so "s" ends at offset 2.
  scanner.scan("test")
  assert_equal(2, scanner.exist?("s"))
  assert_equal(4, scanner.pos)

  # No "e" remains after position 4.
  assert_equal(nil, scanner.exist?("e"))
end

# Exercises StringScanner#scan_until with Regexp patterns on a string that
# contains an embedded NUL byte: each hit returns the text from the current
# position through the end of the match and advances pos past it, while a
# miss returns nil.
def test_scan_until
  s = create_string_scanner("Foo Bar\0Baz")
  assert_equal("Foo", s.scan_until(/Foo/))
  assert_equal(3, s.pos)
  assert_equal(" Bar", s.scan_until(/Bar/))
  assert_equal(7, s.pos)
  # The miss case must go through scan_until itself (this previously called
  # skip_until, leaving scan_until's no-match path untested).
  assert_equal(nil, s.scan_until(/Qux/))
  # The next hit starts at the NUL byte (index 7), so the miss left pos alone.
  assert_equal("\u0000Baz", s.scan_until(/Baz/))
  assert_equal(11, s.pos)
end

# Exercises StringScanner#scan_until with String patterns on a string that
# contains an embedded NUL byte: each hit returns the text from the current
# position through the end of the match and advances pos past it, while a
# miss returns nil.
def test_scan_until_string
  omit("not implemented on TruffleRuby") if RUBY_ENGINE == "truffleruby"
  s = create_string_scanner("Foo Bar\0Baz")
  assert_equal("Foo", s.scan_until("Foo"))
  assert_equal(3, s.pos)
  assert_equal(" Bar", s.scan_until("Bar"))
  assert_equal(7, s.pos)
  # The miss case must go through scan_until itself (this previously called
  # skip_until, leaving scan_until's String no-match path untested).
  assert_equal(nil, s.scan_until("Qux"))
  # The next hit starts at the NUL byte (index 7), so the miss left pos alone.
  assert_equal("\u0000Baz", s.scan_until("Baz"))
  assert_equal(11, s.pos)
end

def test_skip_until
s = create_string_scanner("Foo Bar Baz")
assert_equal(3, s.skip_until(/Foo/))
Expand All @@ -684,6 +721,16 @@ def test_skip_until
assert_equal(nil, s.skip_until(/Qux/))
end

# Exercises StringScanner#skip_until with String patterns: a hit returns the
# number of characters advanced (including the match) and moves pos, and a
# pattern that does not occur returns nil.
def test_skip_until_string
  if RUBY_ENGINE == "truffleruby"
    omit("not implemented on TruffleRuby")
  end

  scanner = create_string_scanner("Foo Bar Baz")

  # "Foo" sits at the very start: 3 characters skipped.
  assert_equal(3, scanner.skip_until("Foo"))
  assert_equal(3, scanner.pos)

  # From position 3, " Bar" is 4 characters long.
  assert_equal(4, scanner.skip_until("Bar"))
  assert_equal(7, scanner.pos)

  assert_equal(nil, scanner.skip_until("Qux"))
end

def test_check_until
s = create_string_scanner("Foo Bar Baz")
assert_equal("Foo", s.check_until(/Foo/))
Expand All @@ -693,6 +740,16 @@ def test_check_until
assert_equal(nil, s.check_until(/Qux/))
end

# Exercises StringScanner#check_until with String patterns: a hit returns the
# text from the current position through the end of the match without moving
# pos, and a pattern that does not occur returns nil.
def test_check_until_string
  if RUBY_ENGINE == "truffleruby"
    omit("not implemented on TruffleRuby")
  end

  scanner = create_string_scanner("Foo Bar Baz")

  assert_equal("Foo", scanner.check_until("Foo"))
  assert_equal(0, scanner.pos)

  # pos is still 0, so the result spans from the start through "Bar".
  assert_equal("Foo Bar", scanner.check_until("Bar"))
  assert_equal(0, scanner.pos)

  assert_equal(nil, scanner.check_until("Qux"))
end

def test_search_full
s = create_string_scanner("Foo Bar Baz")
assert_equal(8, s.search_full(/Bar /, false, false))
Expand All @@ -705,6 +762,19 @@ def test_search_full
assert_equal(11, s.pos)
end

# Exercises StringScanner#search_full with String patterns across the four
# combinations of its two boolean arguments. As the assertions show, the
# second argument decides whether pos advances and the third decides whether
# the matched-through text or its length is returned.
def test_search_full_string
  if RUBY_ENGINE == "truffleruby"
    omit("not implemented on TruffleRuby")
  end

  scanner = create_string_scanner("Foo Bar Baz")

  # No advance, length returned.
  assert_equal(8, scanner.search_full("Bar ", false, false))
  assert_equal(0, scanner.pos)

  # No advance, string returned.
  assert_equal("Foo Bar ", scanner.search_full("Bar ", false, true))
  assert_equal(0, scanner.pos)

  # Advance, length returned.
  assert_equal(8, scanner.search_full("Bar ", true, false))
  assert_equal(8, scanner.pos)

  # Advance, string returned; the search starts at position 8 ("Baz").
  assert_equal("Baz", scanner.search_full("az", true, true))
  assert_equal(11, scanner.pos)
end

def test_peek
s = create_string_scanner("test string")
assert_equal("test st", s.peek(7))
Expand Down