mirror of
https://gerrit.wikimedia.org/r/mediawiki/extensions/Scribunto
synced 2024-11-27 01:30:00 +00:00
cd618c7a92
Lua actually treats a close-bracket at the start of a bracketed character class as a literal, rather than using it to close the character class. Probably unintended behavior, but it happens. Also, have the pure-lua version throw our more informative errors on error even when falling back to string.find and the like, and fix some other weird edge cases that came up in testing. Bug: T95958 Bug: T115686 Change-Id: Iab783d4a3e58b1514cc09729d4a71c2cb1242ee8
1152 lines
28 KiB
Lua
1152 lines
28 KiB
Lua
-- Module table; it is populated below and returned at the end of the file.
local ustring = {}

-- Private references to the standard string functions used by this module,
-- captured once so later changes to the global `string` table do not affect us.
local S = {}
for _, fname in ipairs{
	'byte', 'char', 'len', 'sub', 'find', 'match', 'gmatch', 'gsub', 'format'
} do
	S[fname] = string[fname]
end

---- Configuration ----
-- Upper bounds, in bytes, on the strings and patterns accepted by the
-- argument checkers. The default (math.huge) means "no limit".
ustring.maxStringLength = math.huge
ustring.maxPatternLength = math.huge
|
||
|
||
---- Utility functions ----
|
||
|
||
-- Validate the type of one argument of a public function.
--
-- @param name string Name of the public function (used in the message)
-- @param argidx int Position of the argument (used in the message)
-- @param arg any The value to check
-- @param expecttype string Required result of type( arg )
-- @param nilok boolean When true, a nil argument is accepted
-- Raises an error (level 3) when the type does not match.
local function checkType( name, argidx, arg, expecttype, nilok )
	local missingAllowed = arg == nil and nilok
	if missingAllowed or type( arg ) == expecttype then
		return
	end
	error(
		S.format( "bad argument #%d to '%s' (%s expected, got %s)",
			argidx, name, expecttype, type( arg )
		),
		3
	)
end
|
||
|
||
-- Validate the subject string (argument #1) of a public function.
-- Numbers are accepted and treated as their string form; anything else that
-- is not a string, or a string longer than ustring.maxStringLength bytes,
-- raises an error at level 3.
--
-- @param name string Name of the public function (used in the message)
-- @param s any The value to check
local function checkString( name, s )
	local kind = type( s )
	if kind == 'number' then
		s = tostring( s )
	elseif kind ~= 'string' then
		error(
			S.format( "bad argument #1 to '%s' (string expected, got %s)", name, kind ),
			3
		)
	end

	local limit = ustring.maxStringLength
	if S.len( s ) > limit then
		error(
			S.format( "bad argument #1 to '%s' (string is longer than %d bytes)", name, limit ),
			3
		)
	end
end
|
||
|
||
-- Validate the pattern (argument #2) of a public function.
-- Numbers are accepted and treated as their string form; anything else that
-- is not a string, or a pattern longer than ustring.maxPatternLength bytes,
-- raises an error at level 3.
--
-- @param name string Name of the public function (used in the message)
-- @param pattern any The value to check
local function checkPattern( name, pattern )
	local kind = type( pattern )
	if kind == 'number' then
		pattern = tostring( pattern )
	elseif kind ~= 'string' then
		error(
			S.format( "bad argument #2 to '%s' (string expected, got %s)", name, kind ),
			3
		)
	end

	local limit = ustring.maxPatternLength
	if S.len( pattern ) > limit then
		error(
			S.format( "bad argument #2 to '%s' (pattern is longer than %d bytes)", name, limit ),
			3
		)
	end
end
|
||
|
||
-- A private helper that splits a string into codepoints, and also collects the
-- starting byte position of each character and the total length in codepoints.
--
-- The result has three fields:
--  * len: number of codepoints
--  * codepoints: array of codepoint values
--  * bytepos: array with the starting byte of each codepoint, followed by two
--    extra entries equal to (byte length + 1) so that callers can address
--    "one past the end" without special cases.
--
-- @param s string utf8-encoded string to decode
-- @return table|nil nil when s is not valid UTF-8 (stray or missing
--  continuation byte, overlong encoding, or a value above U+10FFFF)
local function utf8_explode( s )
	local ret = {
		len = 0,
		codepoints = {},
		bytepos = {},
	}

	local i = 1
	local l = S.len( s )
	local cp, b, trail
	local min
	while i <= l do
		b = S.byte( s, i )
		if b < 0x80 then
			-- 1-byte code point, 00-7F
			cp = b
			trail = 0
			min = 0
		elseif b < 0xc2 then
			-- Either a non-initial code point (invalid here) or
			-- an overlong encoding for a 1-byte code point
			return nil
		elseif b < 0xe0 then
			-- 2-byte code point, C2-DF
			trail = 1
			cp = b - 0xc0
			min = 0x80
		elseif b < 0xf0 then
			-- 3-byte code point, E0-EF
			trail = 2
			cp = b - 0xe0
			min = 0x800
		elseif b < 0xf4 then
			-- 4-byte code point, F0-F3
			trail = 3
			cp = b - 0xf0
			min = 0x10000
		elseif b == 0xf4 then
			-- 4-byte code point, F4
			-- Make sure it doesn't decode to over U+10FFFF. The next byte may
			-- be missing entirely when the string is truncated, which is
			-- invalid UTF-8 rather than a reason to raise a compare error.
			local b2 = S.byte( s, i + 1 )
			if not b2 or b2 > 0x8f then
				return nil
			end
			trail = 3
			cp = 4
			min = 0x100000
		else
			-- Code point over U+10FFFF, or invalid byte
			return nil
		end

		-- Check subsequent bytes for multibyte code points
		for j = i + 1, i + trail do
			b = S.byte( s, j )
			if not b or b < 0x80 or b > 0xbf then
				return nil
			end
			cp = cp * 0x40 + b - 0x80
		end
		if cp < min then
			-- Overlong encoding
			return nil
		end

		ret.codepoints[#ret.codepoints + 1] = cp
		ret.bytepos[#ret.bytepos + 1] = i
		ret.len = ret.len + 1
		i = i + 1 + trail
	end

	-- Two past the end (for sub with empty string)
	ret.bytepos[#ret.bytepos + 1] = l + 1
	ret.bytepos[#ret.bytepos + 1] = l + 1

	return ret
end
|
||
|
||
-- A private helper that maps a byte offset to the index of the character
-- containing that byte, using a binary search over cps.bytepos.
--
-- @param cps table from utf8_explode
-- @param i int byte offset
-- @return int character index (0 when i is 0)
local function cpoffset( cps, i )
	local lo, hi = 0, cps.len + 1
	if i == 0 then
		return 0
	end
	while lo + 1 < hi do
		local probe = math.floor( ( lo + hi ) / 2 ) + 1
		local start = cps.bytepos[probe]
		-- Both tests fire when the probe starts exactly at byte i
		if start <= i then
			lo = probe - 1
		end
		if start >= i then
			hi = probe - 1
		end
	end
	return lo + 1
end
|
||
|
||
---- Trivial functions ----
|
||
-- These functions are the same as the standard string versions
|
||
|
||
-- Byte-oriented functions are re-exported from the standard library unchanged
ustring.byte, ustring.format, ustring.rep = string.byte, string.format, string.rep
|
||
|
||
---- Non-trivial functions ----
|
||
-- These functions actually have to be UTF-8 aware
|
||
|
||
|
||
-- Report whether a string is well-formed UTF-8, i.e. whether the internal
-- decoder accepts it.
--
-- @param s string
-- @return boolean
function ustring.isutf8( s )
	checkString( 'isutf8', s )
	local exploded = utf8_explode( s )
	return exploded ~= nil
end
|
||
|
||
-- Return the byte offset of a character in a string
--
-- Raises an error if s is not valid UTF-8.
--
-- @param s string
-- @param l int codepoint number, counted from the character containing
--  byte i [default 1]
-- @param i int starting byte offset; negative values count from the end of
--  the string [default 1]
-- @return int|nil nil when i or the requested character is out of range
function ustring.byteoffset( s, l, i )
	checkString( 'byteoffset', s )
	checkType( 'byteoffset', 2, l, 'number', true )
	checkType( 'byteoffset', 3, i, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'byteoffset' (string is not UTF-8)", 2 )
	end

	-- Apply the documented default; without it the comparison below would
	-- raise "attempt to compare number with nil" when l is omitted.
	l = l or 1
	i = i or 1
	if i < 0 then
		i = S.len( s ) + i + 1
	end
	if i < 1 or i > S.len( s ) then
		return nil
	end
	local p = cpoffset( cps, i )
	-- When i is exactly the start of a character, that character is #1
	if l > 0 and cps.bytepos[p] == i then
		l = l - 1
	end
	if p + l > cps.len then
		return nil
	end
	return cps.bytepos[p + l]
end
|
||
|
||
-- Return codepoints from a string
--
-- Raises an error if s is not valid UTF-8.
--
-- @see string.byte
-- @param s string
-- @param i int Starting character; negative values count from the end
--  [default 1]
-- @param j int Ending character; negative values count from the end
--  [default i]
-- @return int* Zero or more codepoints
function ustring.codepoint( s, i, j )
	checkString( 'codepoint', s )
	checkType( 'codepoint', 2, i, 'number', true )
	checkType( 'codepoint', 3, j, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'codepoint' (string is not UTF-8)", 2 )
	end
	i = i or 1
	if i < 0 then
		i = cps.len + i + 1
	end
	j = j or i
	if j < 0 then
		j = cps.len + j + 1
	end
	-- Clamp to the string first and only then test for an empty range, as
	-- string.byte does. Testing before clamping turned ranges lying wholly
	-- outside the string (e.g. 0..0) into the first character.
	if i < 1 then
		i = 1
	end
	if j > cps.len then
		j = cps.len
	end
	if j < i then
		return -- empty result set
	end
	return unpack( cps.codepoints, i, j )
end
|
||
|
||
-- Return an iterator over the codepoints (as integers)
--   for cp in ustring.gcodepoint( s ) do ... end
--
-- Errors are those raised by ustring.codepoint.
--
-- @param s string
-- @param i int Starting character [default 1]
-- @param j int Ending character [default -1]
-- @return function Iterator returning the next codepoint, or nil when done
function ustring.gcodepoint( s, i, j )
	checkString( 'gcodepoint', s )
	checkType( 'gcodepoint', 2, i, 'number', true )
	checkType( 'gcodepoint', 3, j, 'number', true )
	local cp = { ustring.codepoint( s, i or 1, j or -1 ) }
	-- Walk the array with a cursor. Popping the head with table.remove
	-- shifts every remaining element and makes a full iteration quadratic.
	local idx = 0
	return function ()
		idx = idx + 1
		return cp[idx]
	end
end
|
||
|
||
-- Private helper that UTF-8 encodes a slice of an array of codepoints.
--
-- Non-integer numbers are floored. A non-number entry or a value outside
-- 0..0x10FFFF raises an error naming the argument index.
--
-- @param t table Array of codepoints
-- @param s int Index of the first entry to encode
-- @param e int Index of the last entry to encode
-- @return string
local function internalChar( t, s, e )
	-- Each codepoint is encoded separately and the pieces are joined with
	-- table.concat. Passing every byte to a single S.char( unpack( ... ) )
	-- call fails with "too many results to unpack" on long inputs.
	local ret = {}
	for i = s, e do
		local v = t[i]
		if type( v ) ~= 'number' then
			checkType( 'char', i, v, 'number' )
		end
		v = math.floor( v )
		if v < 0 or v > 0x10ffff then
			error( S.format( "bad argument #%d to 'char' (value out of range)", i ), 2 )
		elseif v < 0x80 then
			ret[#ret + 1] = S.char( v )
		elseif v < 0x800 then
			ret[#ret + 1] = S.char(
				0xc0 + math.floor( v / 0x40 ) % 0x20,
				0x80 + v % 0x40
			)
		elseif v < 0x10000 then
			ret[#ret + 1] = S.char(
				0xe0 + math.floor( v / 0x1000 ) % 0x10,
				0x80 + math.floor( v / 0x40 ) % 0x40,
				0x80 + v % 0x40
			)
		else
			ret[#ret + 1] = S.char(
				0xf0 + math.floor( v / 0x40000 ) % 0x08,
				0x80 + math.floor( v / 0x1000 ) % 0x40,
				0x80 + math.floor( v / 0x40 ) % 0x40,
				0x80 + v % 0x40
			)
		end
	end
	return table.concat( ret )
end

-- Convert codepoints to a string
--
-- @see string.char
-- @param ... int List of codepoints
-- @return string
function ustring.char( ... )
	-- select( '#', ... ) keeps the true argument count even with nil holes,
	-- so a nil argument is reported instead of silently truncating the list
	return internalChar( { ... }, 1, select( '#', ... ) )
end
|
||
|
||
-- Count the codepoints in a string.
--
-- @see string.len
-- @param s string
-- @return int|nil Number of codepoints, or nil if s is not valid UTF-8
function ustring.len( s )
	checkString( 'len', s )
	local exploded = utf8_explode( s )
	if exploded ~= nil then
		return exploded.len
	end
	return nil
end
|
||
|
||
-- Private function to cut a substring by character indexes. The indexes are
-- used as-is: callers must already have resolved defaults and negative values
-- and clamped them to the range covered by cps.bytepos.
--
-- @param s string
-- @param cps table Exploded string
-- @param i int Starting character
-- @param j int Ending character
-- @return string
local function sub( s, cps, i, j )
	local firstByte = cps.bytepos[i]
	-- The last byte of character j is the one before character j+1 starts
	local lastByte = cps.bytepos[j + 1] - 1
	return S.sub( s, firstByte, lastByte )
end
|
||
|
||
-- Return a substring of a string
--
-- Raises an error if s is not valid UTF-8.
--
-- @see string.sub
-- @param s string
-- @param i int Starting character; negative values count from the end
--  [default 1]
-- @param j int Ending character; negative values count from the end
--  [default -1]
-- @return string
function ustring.sub( s, i, j )
	checkString( 'sub', s )
	checkType( 'sub', 2, i, 'number', true )
	checkType( 'sub', 3, j, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'sub' (string is not UTF-8)", 2 )
	end
	i = i or 1
	if i < 0 then
		i = cps.len + i + 1
	end
	j = j or -1
	if j < 0 then
		j = cps.len + j + 1
	end
	-- Clamp to the string first and only then test for an empty range, as
	-- string.sub does. Testing before clamping made ranges lying wholly
	-- before the string (e.g. 0..0 or -10..-5 on "abc") return the first
	-- character instead of the empty string.
	if i < 1 then
		i = 1
	end
	if j > cps.len then
		j = cps.len
	end
	if j < i then
		return ''
	end
	return sub( s, cps, i, j )
end
|
||
|
||
---- Table-driven functions ----
|
||
-- These functions load a conversion table when called
|
||
|
||
-- Convert a string to uppercase using the 'ustring/upper' mapping table,
-- which is loaded on each call via require.
--
-- @see string.upper
-- @param s string
-- @return string
function ustring.upper( s )
	checkString( 'upper', s )
	-- The pattern captures one non-continuation byte plus any continuation
	-- bytes after it; each capture is looked up in the mapping table.
	-- Assigning to a local drops gsub's second result (the match count).
	local converted = S.gsub( s, '([^\128-\191][\128-\191]*)', require 'ustring/upper' )
	return converted
end
|
||
|
||
-- Convert a string to lowercase using the 'ustring/lower' mapping table,
-- which is loaded on each call via require.
--
-- @see string.lower
-- @param s string
-- @return string
function ustring.lower( s )
	checkString( 'lower', s )
	-- The pattern captures one non-continuation byte plus any continuation
	-- bytes after it; each capture is looked up in the mapping table.
	-- Assigning to a local drops gsub's second result (the match count).
	local converted = S.gsub( s, '([^\128-\191][\128-\191]*)', require 'ustring/lower' )
	return converted
end
|
||
|
||
---- Pattern functions ----
|
||
-- Ugh. Just ugh.
|
||
|
||
-- Cache for character sets (e.g. [a-z]), keyed by the raw text of the set.
-- The table is weak so the garbage collector may drop entries. The metatable
-- field that enables this is `__mode`; `__weak` is not recognised by Lua, so
-- a cache configured with it would never release anything.
local charset_cache = {}
setmetatable( charset_cache, { __mode = 'kv' } )
|
||
|
||
-- Private function to find a pattern in a string
-- Yes, this basically reimplements the whole of Lua's pattern matching, in
-- Lua.
--
-- Raises an error for malformed patterns (unmatched brackets or parentheses,
-- missing arguments to %b or %f, invalid back-references, trailing '%').
--
-- @see ustring.find
-- @param s string
-- @param cps table Exploded string
-- @param rawpat string Pattern
-- @param pattern table Exploded pattern
-- @param init int Starting index (in characters; negative counts from the end)
-- @param noAnchor boolean True to ignore '^'
-- @return int starting index of the match (nil if there is no match)
-- @return int ending index of the match
-- @return string|int* captures
local function find( s, cps, rawpat, pattern, init, noAnchor )
	local charsets = require 'ustring/charsets'
	local anchor = false
	local ncapt, captures
	-- captparen[n] is the pattern position of the '(' that opened capture n
	local captparen = {}

	-- Extract the value of a capture from the
	-- upvalues ncapt and captures.
	-- Returns the capture value and its length; raises err when n is not a
	-- valid, closed capture.
	local function getcapt( n, err, errl )
		if n < 1 or n > ncapt then
			error( err, errl + 1 )
		elseif type( captures[n] ) == 'table' then
			if captures[n][2] == '' then
				-- Still open
				error( err, errl + 1 )
			end
			return sub( s, cps, captures[n][1], captures[n][2] ), captures[n][2] - captures[n][1] + 1
		else
			-- Position capture: the value is a number
			return captures[n], math.floor( math.log10( captures[n] ) ) + 1
		end
	end

	local match, match_charset, parse_charset

	-- Main matching function. Uses tail recursion where possible.
	-- Returns the position of the character after the match, and updates the
	-- upvalues ncapt and captures.
	match = function ( sp, pp )
		local c = pattern.codepoints[pp]
		if c == 0x28 then -- '(': starts capture group
			ncapt = ncapt + 1
			captparen[ncapt] = pp
			local ret
			if pattern.codepoints[pp + 1] == 0x29 then -- ')': Pattern is '()', capture position
				captures[ncapt] = sp
				ret = match( sp, pp + 2 )
			else
				-- Start capture group; '' marks the end as not yet known
				captures[ncapt] = { sp, '' }
				ret = match( sp, pp + 1 )
			end
			if ret then
				return ret
			else
				-- Failed, rollback
				ncapt = ncapt - 1
				return nil
			end
		elseif c == 0x29 then -- ')': ends capture group, pop current capture index from stack
			for n = ncapt, 1, -1 do
				if type( captures[n] ) == 'table' and captures[n][2] == '' then
					captures[n][2] = sp - 1
					local ret = match( sp, pp + 1 )
					if ret then
						return ret
					else
						-- Failed, rollback
						captures[n][2] = ''
						return nil
					end
				end
			end
			error( 'Unmatched close-paren at pattern character ' .. pp, 3 )
		elseif c == 0x5b then -- '[': starts character set
			return match_charset( sp, parse_charset( pp ) )
		elseif c == 0x5d then -- ']'
			error( 'Unmatched close-bracket at pattern character ' .. pp, 3 )
		elseif c == 0x25 then -- '%'
			c = pattern.codepoints[pp + 1]
			if not c then -- percent at the end of the pattern
				-- Checked first: the numeric comparison further down would
				-- otherwise raise a compare error on nil before this is reached.
				error( 'malformed pattern (ends with \'%\')', 3 )
			elseif charsets[c] then -- A character set like '%a'
				return match_charset( sp, pp + 2, charsets[c] )
			elseif c == 0x62 then -- '%b': balanced delimiter match
				local d1 = pattern.codepoints[pp + 2]
				local d2 = pattern.codepoints[pp + 3]
				if not d1 or not d2 then
					error( 'malformed pattern (missing arguments to \'%b\')', 3 )
				end
				if cps.codepoints[sp] ~= d1 then
					return nil
				end
				sp = sp + 1
				local ct = 1
				while true do
					c = cps.codepoints[sp]
					sp = sp + 1
					if not c then
						return nil
					elseif c == d2 then
						if ct == 1 then
							return match( sp, pp + 4 )
						end
						ct = ct - 1
					elseif c == d1 then
						ct = ct + 1
					end
				end
			elseif c == 0x66 then -- '%f': frontier pattern match
				if pattern.codepoints[pp + 2] ~= 0x5b then
					error( 'missing \'[\' after %f in pattern at pattern character ' .. pp, 3 )
				end
				local pp, charset = parse_charset( pp + 2 )
				-- Positions outside the string count as codepoint 0
				local c1 = cps.codepoints[sp - 1] or 0
				local c2 = cps.codepoints[sp] or 0
				if not charset[c1] and charset[c2] then
					return match( sp, pp )
				else
					return nil
				end
			elseif c >= 0x30 and c <= 0x39 then -- '%0' to '%9': backreference
				-- Report the digit as written, not its codepoint number
				local idx = c - 0x30
				local m, l = getcapt( idx, 'invalid capture index %' .. idx .. ' at pattern character ' .. pp, 3 )
				local ep = math.min( cps.len + 1, sp + l )
				if sub( s, cps, sp, ep - 1 ) == m then
					return match( ep, pp + 2 )
				else
					return nil
				end
			else -- something else, treat as a literal
				return match_charset( sp, pp + 2, { [c] = 1 } )
			end
		elseif c == 0x2e then -- '.': match anything
			-- Hold the entry in a local so it stays alive for this call even
			-- if the cache drops it.
			local dot = charset_cache['.']
			if not dot then
				local t = {}
				setmetatable( t, { __index = function ( t, k ) return k end } )
				dot = { 1, t }
				charset_cache['.'] = dot
			end
			return match_charset( sp, pp + 1, dot[2] )
		elseif c == nil then -- end of pattern
			return sp
		elseif c == 0x24 and pattern.len == pp then -- '$': assert end of string
			return ( sp == cps.len + 1 ) and sp or nil
		else
			-- Any other character matches itself
			return match_charset( sp, pp + 1, { [c] = 1 } )
		end
	end

	-- Parse a bracketed character set (e.g. [a-z])
	-- Returns the position after the set and a table holding the matching characters
	parse_charset = function ( pp )
		-- Locate the closing bracket in the raw pattern, to build the cache key
		local _, ep
		local epp = pattern.bytepos[pp] + 1
		if S.sub( rawpat, epp, epp ) == '^' then
			epp = epp + 1
		end
		if S.sub( rawpat, epp, epp ) == ']' then
			-- Lua's string module effectively does this
			epp = epp + 1
		end
		repeat
			_, ep = S.find( rawpat, ']', epp, true )
			if not ep then
				error( 'Missing close-bracket for character set beginning at pattern character ' .. pp, 3 )
			end
			epp = ep + 1
		until S.byte( rawpat, ep - 1 ) ~= 0x25 or S.byte( rawpat, ep - 2 ) == 0x25
		local key = S.sub( rawpat, pattern.bytepos[pp], ep )
		local cached = charset_cache[key]
		if cached then
			-- { length of the set in pattern characters, charset table }
			return pp + cached[1], cached[2]
		end

		local p0 = pp
		local cs = {}
		local csrefs = { cs }
		local invert = false
		pp = pp + 1
		if pattern.codepoints[pp] == 0x5e then -- '^'
			invert = true
			pp = pp + 1
		end
		local first = true
		while true do
			local c = pattern.codepoints[pp]
			if not first and c == 0x5d then -- closing ']'
				pp = pp + 1
				break
			elseif c == 0x25 then -- '%'
				c = pattern.codepoints[pp + 1]
				if charsets[c] then
					csrefs[#csrefs + 1] = charsets[c]
				else
					cs[c] = 1
				end
				pp = pp + 2
			elseif pattern.codepoints[pp + 1] == 0x2d and pattern.codepoints[pp + 2] and pattern.codepoints[pp + 2] ~= 0x5d then -- '-' followed by another char (not ']'), it's a range
				for i = c, pattern.codepoints[pp + 2] do
					cs[i] = 1
				end
				pp = pp + 3
			elseif not c then -- Should never get here, but Just In Case...
				error( 'Missing close-bracket', 3 )
			else
				cs[c] = 1
				pp = pp + 1
			end
			first = false
		end

		local ret
		if not csrefs[2] then
			if not invert then
				-- If there's only the one charset table, we can use it directly
				ret = cs
			else
				-- Simple invert
				ret = {}
				setmetatable( ret, { __index = function ( t, k ) return k and not cs[k] end } )
			end
		else
			-- Ok, we have to iterate over multiple charset tables
			ret = {}
			setmetatable( ret, { __index = function ( t, k )
				if not k then
					return nil
				end
				for i = 1, #csrefs do
					if csrefs[i][k] then
						return not invert
					end
				end
				return invert
			end } )
		end

		charset_cache[key] = { pp - p0, ret }
		return pp, ret
	end

	-- Match a character set table with optional quantifier, followed by
	-- the rest of the pattern.
	-- Returns same as 'match' above.
	match_charset = function ( sp, pp, charset )
		local q = pattern.codepoints[pp]
		if q == 0x2a then -- '*', 0 or more matches
			pp = pp + 1
			local i = 0
			while charset[cps.codepoints[sp + i]] do
				i = i + 1
			end
			-- Greedy: try the longest run first, then back off
			while i >= 0 do
				local ret = match( sp + i, pp )
				if ret then
					return ret
				end
				i = i - 1
			end
			return nil
		elseif q == 0x2b then -- '+', 1 or more matches
			pp = pp + 1
			local i = 0
			while charset[cps.codepoints[sp + i]] do
				i = i + 1
			end
			while i > 0 do
				local ret = match( sp + i, pp )
				if ret then
					return ret
				end
				i = i - 1
			end
			return nil
		elseif q == 0x2d then -- '-', 0 or more matches non-greedy
			pp = pp + 1
			while true do
				local ret = match( sp, pp )
				if ret then
					return ret
				end
				if not charset[cps.codepoints[sp]] then
					return nil
				end
				sp = sp + 1
			end
		elseif q == 0x3f then -- '?', 0 or 1 match
			pp = pp + 1
			if charset[cps.codepoints[sp]] then
				local ret = match( sp + 1, pp )
				if ret then
					return ret
				end
			end
			return match( sp, pp )
		else -- no suffix, must match 1
			if charset[cps.codepoints[sp]] then
				return match( sp + 1, pp )
			else
				return nil
			end
		end
	end

	init = init or 1
	if init < 0 then
		init = cps.len + init + 1
	end
	init = math.max( 1, math.min( init, cps.len + 1 ) )

	-- Here is the actual match loop. It just calls 'match' on successive
	-- starting positions (or not, if the pattern is anchored) until it finds a
	-- match.
	local sp = init
	local pp = 1
	if not noAnchor and pattern.codepoints[1] == 0x5e then -- '^': Pattern is anchored
		anchor = true
		pp = 2
	end

	repeat
		ncapt, captures = 0, {}
		local ep = match( sp, pp )
		if ep then
			for i = 1, ncapt do
				-- captparen is indexed by capture number. Indexing it by the
				-- pattern position could yield nil and make the concatenation
				-- itself raise an error on a perfectly good match.
				captures[i] = getcapt( i, 'Unclosed capture beginning at pattern character ' .. captparen[i], 2 )
			end
			return sp, ep - 1, unpack( captures, 1, ncapt )
		end
		sp = sp + 1
	until anchor or sp > cps.len + 1
	return nil
end
|
||
|
||
-- Private function that reports whether a pattern can safely be handed to
-- Lua's built-in string library. A pattern is NOT simple when:
--  * It contains any byte over 0x7f. (These would be harmless outside
--    brackets, without a quantifier and outside '%b', but checking for that
--    is too complicated.)
--  * It contains a negated character set.
--  * It contains one of the %-prefixed character classes listed in the
--    check below.
--  * It contains a '.' that is not followed by '*', '+' or '-'. A bare '.'
--    or '.?' would match part of a UTF-8 character, while the quantified
--    forms happily match a whole character thinking it is 2, 3 or 4.
--  * It contains a position capture.
--
-- @param pattern string
-- @return boolean
local function patternIsSimple( pattern )
	local needsUnicode =
		S.find( pattern, '[\128-\255]' ) or
		S.find( pattern, '%[%^' ) or
		S.find( pattern, '%%[acdlpsuwxACDLPSUWXZ]' ) or
		S.find( pattern, '%.[^*+-]' ) or S.find( pattern, '%.$' ) or
		S.find( pattern, '()', 1, true )
	return not needsUnicode
end
|
||
|
||
-- Find a pattern in a string
--
-- This works just like string.find, with the following changes:
--  * Everything works on UTF-8 characters rather than bytes
--  * Character classes are redefined in terms of Unicode properties:
--    * %a - Letter
--    * %c - Control
--    * %d - Decimal Number
--    * %l - Lower case letter
--    * %p - Punctuation
--    * %s - Separator, plus HT, LF, FF, CR, and VT
--    * %u - Upper case letter
--    * %w - Letter or Decimal Number
--    * %x - Hexadecimal digit
--
-- Raises an error if s or pattern is not valid UTF-8, or if the pattern is
-- malformed.
--
-- @see string.find
-- @param s string
-- @param pattern string Pattern
-- @param init int Starting index, in characters; negative values count from
--  the end of the string [default 1]
-- @param plain boolean Literal match, no pattern matching
-- @return int starting index of the match
-- @return int ending index of the match
-- @return string|int* captures
function ustring.find( s, pattern, init, plain )
	checkString( 'find', s )
	checkPattern( 'find', pattern )
	checkType( 'find', 3, init, 'number', true )
	checkType( 'find', 4, plain, 'boolean', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'find' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'find' (string is not UTF-8)", 2 )
	end

	if plain or patternIsSimple( pattern ) then
		-- Resolve init to a character index inside the string before turning
		-- it into a byte offset. A negative init used as a direct index into
		-- cps.bytepos would yield nil and silently restart the search at 1.
		local start = init or 1
		if start < 0 then
			start = cps.len + start + 1
		end
		start = math.max( 1, math.min( start, cps.len + 1 ) )

		local m
		if plain then
			m = { true, S.find( s, pattern, cps.bytepos[start], plain ) }
		else
			-- A malformed pattern makes S.find raise; in that case fall
			-- through to the pure-Lua matcher for its more informative errors.
			m = { pcall( S.find, s, pattern, cps.bytepos[start], plain ) }
		end
		if m[1] then
			if m[2] then
				-- Convert byte offsets back to character indexes
				m[2] = cpoffset( cps, m[2] )
				m[3] = cpoffset( cps, m[3] )
			end
			return unpack( m, 2 )
		end
	end

	return find( s, cps, pattern, pat, init )
end
|
||
|
||
-- Match a string against a pattern
--
-- Raises an error if s or pattern is not valid UTF-8, or if the pattern is
-- malformed.
--
-- @see ustring.find
-- @see string.match
-- @param s string
-- @param pattern string
-- @param init int Starting offset for match, in characters; negative values
--  count from the end of the string [default 1]
-- @return string|int* captures, or the whole match if there are none
function ustring.match( s, pattern, init )
	checkString( 'match', s )
	checkPattern( 'match', pattern )
	checkType( 'match', 3, init, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'match' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'match' (string is not UTF-8)", 2 )
	end

	if patternIsSimple( pattern ) then
		-- Resolve init to a character index inside the string before turning
		-- it into a byte offset. A negative or too-large init used as a
		-- direct index into cps.bytepos would yield nil and silently restart
		-- the match at the beginning of the string.
		local start = init or 1
		if start < 0 then
			start = cps.len + start + 1
		end
		start = math.max( 1, math.min( start, cps.len + 1 ) )

		-- A malformed pattern makes S.match raise; in that case fall through
		-- to the pure-Lua matcher for its more informative errors.
		local ret = { pcall( S.match, s, pattern, cps.bytepos[start] ) }
		if ret[1] then
			return unpack( ret, 2 )
		end
	end

	local m = { find( s, cps, pattern, pat, init ) }
	if not m[1] then
		return nil
	end
	if m[3] then
		return unpack( m, 3 )
	end
	return sub( s, cps, m[1], m[2] )
end
|
||
|
||
-- Return an iterator function over the matches for a pattern
--
-- As with string.gmatch, a leading '^' is not treated as an anchor.
-- On the pure-Lua path, raises an error if s or pattern is not valid UTF-8.
--
-- @see ustring.find
-- @see string.gmatch
-- @param s string
-- @param pattern string
-- @return function Iterator returning the captures of the next match (or the
--  whole match if there are none), or nil when there are no more matches
function ustring.gmatch( s, pattern )
	checkString( 'gmatch', s )
	checkPattern( 'gmatch', pattern )
	if patternIsSimple( pattern ) then
		local ret = { pcall( S.gmatch, s, pattern ) }
		if ret[1] then
			return unpack( ret, 2 )
		end
	end

	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'gmatch' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'gmatch' (string is not UTF-8)", 2 )
	end
	local init = 1

	return function ()
		-- find() clamps its starting index into the string, so the end of
		-- the iteration has to be detected here.
		if init > cps.len + 1 then
			return nil
		end
		local m = { find( s, cps, pattern, pat, init, true ) }
		if not m[1] then
			init = cps.len + 2
			return nil
		end
		if m[2] < m[1] then
			-- Empty match: step over one character, otherwise the same empty
			-- match would be returned forever.
			init = m[1] + 1
		else
			init = m[2] + 1
		end
		if m[3] then
			return unpack( m, 3 )
		end
		return sub( s, cps, m[1], m[2] )
	end
end
|
||
|
||
-- Replace pattern matches in a string
--
-- Raises an error on the pure-Lua path if s or pattern is not valid UTF-8,
-- if repl has an unsupported type, or if a string repl refers to a capture
-- that does not exist.
--
-- @see ustring.find
-- @see string.gsub
-- @param s string
-- @param pattern string
-- @param repl string|function|table Replacement (numbers are converted to
--  strings)
-- @param n int Maximum number of replacements [default: unlimited]
-- @return string The string with replacements applied
-- @return int Number of matches found
function ustring.gsub( s, pattern, repl, n )
	checkString( 'gsub', s )
	checkPattern( 'gsub', pattern )
	checkType( 'gsub', 4, n, 'number', true )
	if patternIsSimple( pattern ) then
		local ret = { pcall( S.gsub, s, pattern, repl, n ) }
		if ret[1] then
			return unpack( ret, 2 )
		end
	end

	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'gsub' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'gsub' (string is not UTF-8)", 2 )
	end
	if n == nil then
		n = 1e100
	end

	if pat.codepoints[1] == 0x5e and n > 1 then -- '^': Pattern is anchored
		-- There can be only the one match, so make that explicit. A caller
		-- limit below 1 is left alone so that n = 0 still replaces nothing.
		n = 1
	end

	local tp
	if type( repl ) == 'function' then
		tp = 1
	elseif type( repl ) == 'table' then
		tp = 2
	elseif type( repl ) == 'string' then
		tp = 3
	elseif type( repl ) == 'number' then
		repl = tostring( repl )
		tp = 3
	else
		-- Always raises: no value has this type name
		checkType( 'gsub', 3, repl, 'function or table or string' )
	end

	local init = 1
	local ct = 0
	local ret = {}
	-- 1 right after an empty match, so that the next search starts one
	-- character further on instead of finding the same empty match forever
	local zeroAdjustment = 0
	while ct < n do
		local m = { find( s, cps, pattern, pat, init + zeroAdjustment ) }
		if not m[1] then
			break
		end
		if init < m[1] then
			-- Copy the unmatched text before this match
			ret[#ret + 1] = sub( s, cps, init, m[1] - 1 )
		end
		local mm = sub( s, cps, m[1], m[2] )

		-- Without captures the whole match acts as the first capture. This
		-- is also what string.gsub does for '%1' in a replacement string.
		if #m < 3 then
			m[3] = mm
		end

		local val
		if tp == 1 then
			val = repl( unpack( m, 3 ) )
		elseif tp == 2 then
			val = repl[m[3]]
		elseif tp == 3 then
			if ct == 0 and #m < 11 then
				-- Blank out every valid %-sequence; any %digit left over
				-- refers to a capture that does not exist
				local ss = S.gsub( repl, '%%[%%0-' .. ( #m - 2 ) .. ']', 'x' )
				ss = S.match( ss, '%%[0-9]' )
				if ss then
					error( 'invalid capture index ' .. ss .. ' in replacement string', 2 )
				end
			end
			local t = {
				["%0"] = mm,
				["%1"] = m[3],
				["%2"] = m[4],
				["%3"] = m[5],
				["%4"] = m[6],
				["%5"] = m[7],
				["%6"] = m[8],
				["%7"] = m[9],
				["%8"] = m[10],
				["%9"] = m[11],
				["%%"] = "%"
			}
			val = S.gsub( repl, '%%[%%0-9]', t )
		end
		-- A false or nil replacement keeps the original match
		ret[#ret + 1] = val or mm
		init = m[2] + 1
		ct = ct + 1
		zeroAdjustment = ( m[2] < m[1] ) and 1 or 0
		if init > cps.len then
			break
		end
	end
	if init <= cps.len then
		-- Copy whatever follows the last match
		ret[#ret + 1] = sub( s, cps, init, cps.len )
	end
	return table.concat( ret ), ct
end
|
||
|
||
---- Unicode Normalization ----
|
||
-- These functions load a conversion table when called
|
||
|
||
-- Private helper that decomposes an exploded string into an array of
-- codepoints, using the tables from 'ustring/normalization-data', and then
-- reorders adjacent codepoints by combining class.
--
-- @param cps table Exploded string from utf8_explode
-- @return table Array of codepoints
-- @return int Index of the first codepoint (always 1)
-- @return int Index of the last codepoint
local function internalToNFD( cps )
	local normal = require 'ustring/normalization-data'
	local out = {}

	-- Expand every codepoint that has an entry in the decomposition table
	for idx = 1, cps.len do
		local c = cps.codepoints[idx]
		local pieces = normal.decomp[c]
		if not pieces then
			out[#out + 1] = c
		else
			-- The loop starts at index 0; if an entry has no [0] element the
			-- assignment of nil is simply a no-op.
			-- NOTE(review): whether decomp entries are 0- or 1-based is not
			-- visible here -- confirm against the data file.
			for j = 0, #pieces do
				out[#out + 1] = pieces[j]
			end
		end
	end

	-- Swap neighbours whose combining classes are out of order, stepping back
	-- one position after each swap so the moved codepoint is re-checked.
	local total = #out
	local pos = 1
	while pos < total do
		local here = normal.combclass[out[pos]]
		local following = normal.combclass[out[pos + 1]]
		if here and following and here > following then
			out[pos], out[pos + 1] = out[pos + 1], out[pos]
			if pos > 1 then
				pos = pos - 1
			else
				pos = pos + 1
			end
		else
			pos = pos + 1
		end
	end

	return out, 1, total
end
|
||
|
||
-- Normalize a string to NFC
--
-- Based on MediaWiki's UtfNormal class.
--
-- @param s string
-- @return string|nil nil if the string is not valid UTF-8
function ustring.toNFC( s )
	checkString( 'toNFC', s )

	-- Pure ASCII needs no normalization
	if not S.find( s, '[\128-\255]' ) then
		return s
	end

	local cps = utf8_explode( s )
	if cps == nil then
		return nil
	end
	local normal = require 'ustring/normalization-data'

	-- Quick check: if no codepoint is flagged in normal.check, the string is
	-- returned unchanged
	local needsWork = false
	for i = 1, cps.len do
		if normal.check[cps.codepoints[i]] then
			needsWork = true
			break
		end
	end
	if not needsWork then
		return s
	end

	-- Decompose first
	local cp, _, l = internalToNFD( cps )

	-- Then recompose in place: the output never grows, so the write index
	-- can never overtake the read index.
	local starterComp = normal.comp[cp[1]] -- compositions available for the current starter
	local starterIdx = 1                   -- where the current starter lives in the output
	local out = 1                          -- last output index written
	local prevClass = 0                    -- combining class of the last uncombined combiner
	for i = 2, l do
		local c = cp[i]
		local ccc = normal.combclass[c]

		-- Does c combine with the current starter?
		local merged
		if ccc then
			-- A combiner only combines when its class is higher than the
			-- class of the last combiner that was kept
			merged = starterComp and prevClass < ccc and starterComp[c]
		else
			-- Two starters only combine when they are adjacent
			merged = starterComp and prevClass == 0 and starterComp[c]
		end

		if merged then
			cp[starterIdx] = merged
			starterComp = normal.comp[merged]
		else
			-- Keep c as its own output codepoint
			out = out + 1
			cp[out] = c
			if ccc then
				prevClass = ccc
			else
				-- c becomes the new starter
				starterComp = normal.comp[c]
				starterIdx = out
				prevClass = 0
			end
		end
	end

	return internalChar( cp, 1, out )
end
|
||
|
||
-- Normalize a string to NFD
--
-- Based on MediaWiki's UtfNormal class.
--
-- @param s string
-- @return string|nil nil if the string is not valid UTF-8
function ustring.toNFD( s )
	checkString( 'toNFD', s )

	-- Pure ASCII has nothing to decompose, so it is returned as-is
	if not S.find( s, '[\128-\255]' ) then
		return s
	end

	local cps = utf8_explode( s )
	if cps ~= nil then
		-- internalToNFD returns ( codepoints, first, last ), which is exactly
		-- the argument list internalChar expects
		return internalChar( internalToNFD( cps ) )
	end
	return nil
end
|
||
|
||
-- Export the module table
return ustring
|