mirror of
https://gerrit.wikimedia.org/r/mediawiki/extensions/Scribunto
synced 2024-11-25 08:36:21 +00:00
0a8757baba
This is a reimplementation of Lua's string library with support for UTF-8. The entire ustring library is implemented in pure Lua. PHP callbacks are also available for overrides: in LuaSandbox these are used for almost all functions, while in LuaStandalone they are used only for the pattern matching. Also, ustring.upper and ustring.lower are overridden using mw.language's .uc and .lc if available. It also includes a bunch of unit tests. Note that if you download the normalization tests, they may fail under LuaSandbox if you have PHP's intl extension installed and libicu on your system is too old. Change-Id: Ie76fdf8d3a85d0a3d2a41b0d3b7afe433f247af0
1102 lines
27 KiB
Lua
1102 lines
27 KiB
Lua
local ustring = {}

-- Private copies of the string library functions used by this module,
-- captured at load time so that later changes to the global `string`
-- table cannot alter ustring's behavior.
local S = {
	byte = string.byte,
	char = string.char,
	len = string.len,
	sub = string.sub,
	find = string.find,
	match = string.match,
	gmatch = string.gmatch,
	gsub = string.gsub,
	format = string.format,
}

---- Configuration ----
-- To limit the length of strings or patterns processed, set these to a
-- byte count. The default, math.huge, means "no limit".
-- (`inf` is not a Lua built-in: as a bare name it is an undefined global,
-- i.e. nil, which made every length check fail with
-- "attempt to compare nil with number".)
ustring.maxStringLength = math.huge
ustring.maxPatternLength = math.huge
|
||
|
||
---- Utility functions ----

-- Raise a "bad argument" error unless `arg` has Lua type `expecttype`.
-- A nil `arg` is accepted when `nilok` is truthy. The error is raised at
-- level 3, i.e. two frames above this helper (normally the code that called
-- the public function doing the check).
--
-- @param name string Public function name used in the message
-- @param argidx int Argument position used in the message
-- @param arg any Value being checked
-- @param expecttype string Expected result of type( arg )
-- @param nilok boolean Whether nil is acceptable
local function checkType( name, argidx, arg, expecttype, nilok )
	local actual = type( arg )
	if actual == expecttype or ( nilok and arg == nil ) then
		return
	end
	error( S.format( "bad argument #%d to '%s' (%s expected, got %s)",
		argidx, name, expecttype, actual ), 3 )
end
|
||
|
||
-- Validate argument #1 of the public function `name`: it must be a string
-- no longer than ustring.maxStringLength bytes. Errors are raised at
-- level 3, i.e. against the caller of that public function.
--
-- @param name string Public function name used in the message
-- @param s any Value to validate
local function checkString( name, s )
	local kind = type( s )
	if kind ~= 'string' then
		error( S.format( "bad argument #1 to '%s' (string expected, got %s)", name, kind ), 3 )
	end
	local limit = ustring.maxStringLength
	if S.len( s ) > limit then
		error( S.format( "bad argument #1 to '%s' (string is longer than %d bytes)", name, limit ), 3 )
	end
end
|
||
|
||
-- Validate argument #2 (the pattern) of the public function `name`: it must
-- be a string no longer than ustring.maxPatternLength bytes. Errors are
-- raised at level 3, i.e. against the caller of that public function.
--
-- @param name string Public function name used in the message
-- @param pattern any Value to validate
local function checkPattern( name, pattern )
	local kind = type( pattern )
	if kind ~= 'string' then
		error( S.format( "bad argument #2 to '%s' (string expected, got %s)", name, kind ), 3 )
	end
	local limit = ustring.maxPatternLength
	if S.len( pattern ) > limit then
		error( S.format( "bad argument #2 to '%s' (pattern is longer than %d bytes)", name, limit ), 3 )
	end
end
|
||
|
||
-- A private helper that splits a string into codepoints, and also collects
-- the starting byte position of each character and the total length in
-- codepoints.
--
-- Returns nil for stray continuation bytes, overlong encodings, truncated
-- sequences, lead bytes F5-FF, and anything above U+10FFFF.
-- NOTE(review): encoded UTF-16 surrogates (U+D800-U+DFFF, lead byte ED)
-- are not rejected here; confirm whether that is intended.
--
-- @param s string utf8-encoded string to decode
-- @return table|nil nil if s is not valid UTF-8, otherwise a table with
--   len        number of codepoints
--   codepoints array of codepoints, one per character
--   bytepos    byte offset where each character starts, followed by one
--              extra entry (#s + 1) marking the position one past the end
local function utf8_explode( s )
	local ret = {
		len = 0,
		codepoints = {},
		bytepos = {},
	}

	local i = 1
	local l = S.len( s )
	local cp, b, b2, trail
	local min
	while i <= l do
		b = S.byte( s, i )
		if b < 0x80 then
			-- 1-byte code point, 00-7F
			cp = b
			trail = 0
			min = 0
		elseif b < 0xc2 then
			-- Either a non-initial code point (invalid here) or
			-- an overlong encoding for a 1-byte code point
			return nil
		elseif b < 0xe0 then
			-- 2-byte code point, C2-DF
			trail = 1
			cp = b - 0xc0
			min = 0x80
		elseif b < 0xf0 then
			-- 3-byte code point, E0-EF
			trail = 2
			cp = b - 0xe0
			min = 0x800
		elseif b < 0xf4 then
			-- 4-byte code point, F0-F3
			trail = 3
			cp = b - 0xf0
			min = 0x10000
		elseif b == 0xf4 then
			-- 4-byte code point, F4
			-- Make sure it doesn't decode to over U+10FFFF. If the string
			-- ends right after the lead byte, b2 is nil; comparing it would
			-- raise "attempt to compare nil with number", so leave the
			-- missing byte for the trail-byte check below to reject.
			b2 = S.byte( s, i + 1 )
			if b2 and b2 > 0x8f then
				return nil
			end
			trail = 3
			cp = 4
			min = 0x100000
		else
			-- Code point over U+10FFFF, or invalid byte
			return nil
		end

		-- Check subsequent bytes for multibyte code points
		for j = i + 1, i + trail do
			b = S.byte( s, j )
			if not b or b < 0x80 or b > 0xbf then
				return nil
			end
			cp = cp * 0x40 + b - 0x80
		end
		if cp < min then
			-- Overlong encoding
			return nil
		end

		ret.codepoints[#ret.codepoints + 1] = cp
		ret.bytepos[#ret.bytepos + 1] = i
		ret.len = ret.len + 1
		i = i + 1 + trail
	end

	-- One past the end
	ret.bytepos[#ret.bytepos + 1] = l + 1

	return ret
end
|
||
|
||
-- A private helper that maps a byte offset to the index of the character
-- containing that byte.
--
-- Binary-searches cps.bytepos for the last character whose first byte is at
-- or before byte i. Offsets below 1 give 1; offsets at or past #s + 1 give
-- cps.len + 1.
--
-- @param cps table from utf8_explode
-- @param i int byte offset
-- @return int character index
local function cpoffset( cps, i )
	local lo, hi = 1, cps.len + 1
	while lo < hi do
		-- Round up so that lo always moves when the probe succeeds
		local mid = math.floor( ( lo + hi + 1 ) / 2 )
		if cps.bytepos[mid] <= i then
			lo = mid
		else
			hi = mid - 1
		end
	end
	return lo
end
|
||
|
||
---- Trivial functions ----
-- Byte-oriented operations that need no UTF-8 awareness: the standard
-- string library versions are exported unchanged.
ustring.byte, ustring.format, ustring.rep = string.byte, string.format, string.rep
|
||
|
||
---- Non-trivial functions ----
-- These functions actually have to be UTF-8 aware

-- Test whether a string is well-formed UTF-8
--
-- @param s string
-- @return boolean true when utf8_explode can decode s
function ustring.isutf8( s )
	checkString( 'isutf8', s )
	if utf8_explode( s ) then
		return true
	end
	return false
end
|
||
|
||
-- Return the byte offset of a character in a string
--
-- Characters are counted relative to the one containing byte i: l = 1 is
-- the character starting at byte i (or, when i falls inside a multibyte
-- character, the next character), l = 0 is the start of the character
-- containing byte i, and other values step l characters forward or back.
--
-- @param s string
-- @param l int codepoint number [default 1]
-- @param i int starting byte offset [default 1]; negative counts from the end
-- @return int|nil nil when i or the requested character is out of range
function ustring.byteoffset( s, l, i )
	checkString( 'byteoffset', s )
	checkType( 'byteoffset', 2, l, 'number', true )
	checkType( 'byteoffset', 3, i, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'byteoffset' (string is not UTF-8)", 2 )
	end

	-- Both arguments are optional. Without this default, a missing l
	-- crashed below with "attempt to compare nil with number".
	l = l or 1
	i = i or 1
	local bytelen = S.len( s )
	if i < 0 then
		i = bytelen + i + 1
	end
	if i < 1 or i > bytelen then
		return nil
	end
	local p = cpoffset( cps, i )
	if l > 0 and cps.bytepos[p] == i then
		-- Byte i starts a character, so that character is itself number 1
		l = l - 1
	end
	if p + l > cps.len then
		return nil
	end
	return cps.bytepos[p + l]
end
|
||
|
||
-- Return codepoints from a string
--
-- Mirrors string.byte, but indexes are characters rather than bytes.
-- Negative indexes count from the end. The range is clipped to the string,
-- and an empty range returns no values. Any range on an empty string is
-- empty.
--
-- @see string.byte
-- @param s string
-- @param i int Starting character [default 1]
-- @param j int Ending character [default i]
-- @return int* Zero or more codepoints
function ustring.codepoint( s, i, j )
	checkString( 'codepoint', s )
	checkType( 'codepoint', 2, i, 'number', true )
	checkType( 'codepoint', 3, j, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'codepoint' (string is not UTF-8)", 2 )
	end
	i = i or 1
	if i < 0 then
		i = cps.len + i + 1
	end
	j = j or i
	if j < 0 then
		j = cps.len + j + 1
	end
	-- Clip the way string.byte does. Forcing both ends into [1, len]
	-- made out-of-range or empty ranges return a character (or a lone nil
	-- for an empty string) instead of nothing.
	if i < 1 then
		i = 1
	end
	if j > cps.len then
		j = cps.len
	end
	if i > j then
		return
	end
	return unpack( cps.codepoints, i, j )
end
|
||
|
||
-- Return an iterator over the codepoints (as integers)
--   for cp in ustring.gcodepoint( s ) do ... end
--
-- @param s string
-- @param i int Starting character [default 1]
-- @param j int Ending character [default -1]
-- @return function Iterator yielding one codepoint per call, then nil
function ustring.gcodepoint( s, i, j )
	checkString( 'gcodepoint', s )
	checkType( 'gcodepoint', 2, i, 'number', true )
	checkType( 'gcodepoint', 3, j, 'number', true )
	local cp = { ustring.codepoint( s, i or 1, j or -1 ) }
	-- Walk the array with an index. Calling table.remove( cp, 1 ) shifted
	-- all remaining elements on every step, making a full iteration O(n^2).
	local pos = 0
	return function ()
		pos = pos + 1
		return cp[pos]
	end
end
|
||
|
||
-- Private helper: encode the codepoints t[s..e] as a UTF-8 string.
--
-- Non-integer values are rounded down. A non-number value is reported
-- through checkType. A value outside 0..0x10FFFF raises an error naming
-- argument number i of 'char'.
--
-- @param t table Codepoints
-- @param s int First index in t
-- @param e int Last index in t
-- @return string
local function internalChar( t, s, e )
	local ret = {}
	for i = s, e do
		local v = t[i]
		if type( v ) ~= 'number' then
			checkType( 'char', i, v, 'number' )
		end
		v = math.floor( v )
		if v < 0 or v > 0x10ffff then
			error( S.format( "bad argument #%d to 'char' (value out of range)", i ), 2 )
		elseif v < 0x80 then
			ret[#ret + 1] = v
		elseif v < 0x800 then
			ret[#ret + 1] = 0xc0 + math.floor( v / 0x40 ) % 0x20
			ret[#ret + 1] = 0x80 + v % 0x40
		elseif v < 0x10000 then
			ret[#ret + 1] = 0xe0 + math.floor( v / 0x1000 ) % 0x10
			ret[#ret + 1] = 0x80 + math.floor( v / 0x40 ) % 0x40
			ret[#ret + 1] = 0x80 + v % 0x40
		else
			ret[#ret + 1] = 0xf0 + math.floor( v / 0x40000 ) % 0x08
			ret[#ret + 1] = 0x80 + math.floor( v / 0x1000 ) % 0x40
			ret[#ret + 1] = 0x80 + math.floor( v / 0x40 ) % 0x40
			ret[#ret + 1] = 0x80 + v % 0x40
		end
	end

	-- unpack() can only return as many values as fit on the Lua stack
	-- (bounded by LUAI_MAXCSTACK). Spreading one long byte list in a single
	-- call therefore failed with "too many results to unpack" for long
	-- inputs such as toNFC/toNFD output. Convert in bounded chunks instead.
	local chunk = 4096
	local n = #ret
	local parts = {}
	for k = 1, n, chunk do
		parts[#parts + 1] = S.char( unpack( ret, k, math.min( k + chunk - 1, n ) ) )
	end
	return table.concat( parts )
end

-- Convert codepoints to a string
--
-- @see string.char
-- @param ... int List of codepoints
-- @return string
function ustring.char( ... )
	return internalChar( { ... }, 1, select( '#', ... ) )
end
|
||
|
||
-- Count the characters (codepoints) in a string.
-- Returns nil instead when the string is not valid UTF-8.
--
-- @see string.len
-- @param s string
-- @return int|nil
function ustring.len( s )
	checkString( 'len', s )
	local cps = utf8_explode( s )
	return cps and cps.len or nil
end
|
||
|
||
-- Private helper: return the bytes of s that make up characters i..j.
--
-- No defaults or clamping are applied here. Callers must keep i and j + 1
-- within 1..cps.len + 1; j = i - 1 yields "".
--
-- @param s string
-- @param cps table Exploded string (from utf8_explode)
-- @param i int First character
-- @param j int Last character
-- @return string
local function sub( s, cps, i, j )
	local bytepos = cps.bytepos
	return S.sub( s, bytepos[i], bytepos[j + 1] - 1 )
end
|
||
|
||
-- Return a substring of a string
--
-- Like string.sub, but i and j are character indexes. Negative indexes
-- count from the end, and out-of-range indexes are clipped to the string.
--
-- @see string.sub
-- @param s string
-- @param i int Starting character [default 1]
-- @param j int Ending character [default -1]
-- @return string
function ustring.sub( s, i, j )
	checkString( 'sub', s )
	checkType( 'sub', 2, i, 'number', true )
	checkType( 'sub', 3, j, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'sub' (string is not UTF-8)", 2 )
	end
	i = i or 1
	if i < 0 then
		i = cps.len + i + 1
	end
	j = j or -1
	if j < 0 then
		j = cps.len + j + 1
	end
	-- Clip to the string. i may reach len + 1 and j may reach 0, and both
	-- yield "". j must not exceed len because the helper reads
	-- bytepos[j + 1], and utf8_explode stores only len + 1 entries.
	-- Previously j was clamped to [1, len + 1]: that crashed on e.g.
	-- sub( "", 1 ) or sub( "abc", 1, 10 ), and sub( "abc", 1, 0 ) returned "a".
	i = math.max( 1, math.min( i, cps.len + 1 ) )
	j = math.max( 0, math.min( j, cps.len ) )
	return sub( s, cps, i, j )
end
|
||
|
||
---- Table-driven functions ----
-- These functions load a conversion table when called

-- Convert a string to uppercase
--
-- Each character (a non-continuation byte plus any following continuation
-- bytes) is looked up in the 'ustring/upper' table. Characters with no
-- entry are left as they are.
--
-- @see string.upper
-- @param s string
-- @return string
function ustring.upper( s )
	checkString( 'upper', s )
	local upperMap = require 'ustring/upper'
	return ( S.gsub( s, '([^\128-\191][\128-\191]*)', upperMap ) )
end
|
||
|
||
-- Convert a string to lowercase
--
-- Each character (a non-continuation byte plus any following continuation
-- bytes) is looked up in the 'ustring/lower' table. Characters with no
-- entry are left as they are.
--
-- @see string.lower
-- @param s string
-- @return string
function ustring.lower( s )
	checkString( 'lower', s )
	local lowerMap = require 'ustring/lower'
	return ( S.gsub( s, '([^\128-\191][\128-\191]*)', lowerMap ) )
end
|
||
|
||
---- Pattern functions ----
-- Ugh. Just ugh.

-- Cache of parsed character sets (e.g. [a-z]), keyed by the set's source
-- text ('.' for the match-anything set). Each entry is
-- { pattern length in characters, membership table }.
-- The cache is weak so it cannot grow without bound. The metafield that
-- does this is __mode; the previous __weak is not a Lua metafield and was
-- silently ignored, leaving a strong, ever-growing cache.
local charset_cache = {}
setmetatable( charset_cache, { __mode = 'kv' } )
|
||
|
||
-- Private function to find a pattern in a string
-- Yes, this basically reimplements the whole of Lua's pattern matching, in
-- Lua.
--
-- Positions are character (codepoint) indexes, not byte offsets. Starting at
-- character init, the search moves forward one character at a time, unless
-- the pattern is anchored with '^'. It goes up to and including
-- cps.len + 1, so patterns that can match only the empty string at the end
-- (e.g. '%s*$') are found, as they are by string.find.
--
-- @see ustring.find
-- @param s string
-- @param cps table Exploded string
-- @param rawpat string Pattern
-- @param pattern table Exploded pattern
-- @param init int Starting index
-- @return int starting index of the match
-- @return int ending index of the match
-- @return string|int* captures
local function find( s, cps, rawpat, pattern, init )
	local charsets = require 'ustring/charsets'
	local anchor = false
	local ncapt, captures
	local captparen = {}

	-- Extract the value of capture n from the upvalues ncapt and captures.
	-- Returns the captured substring (or, for a position capture, the
	-- position) and its length: characters for a substring, decimal digits
	-- for a position. Raises err at level errl + 1 if capture n does not
	-- exist or is still open.
	local function getcapt( n, err, errl )
		-- n == 0 comes from a '%0' backreference, which is invalid
		if n < 1 or n > ncapt then
			error( err, errl + 1 )
		elseif type( captures[n] ) == 'table' then
			if captures[n][2] == '' then
				error( err, errl + 1 )
			end
			return sub( s, cps, captures[n][1], captures[n][2] ), captures[n][2] - captures[n][1] + 1
		else
			return captures[n], math.floor( math.log10( captures[n] ) ) + 1
		end
	end

	local match, match_charset, parse_charset

	-- Main matching function. Uses tail recursion where possible.
	-- Returns the position of the character after the match, and updates the
	-- upvalues ncapt and captures.
	match = function ( sp, pp )
		local c = pattern.codepoints[pp]
		if c == 0x28 then -- '(': starts capture group
			ncapt = ncapt + 1
			captparen[ncapt] = pp
			local ret
			if pattern.codepoints[pp + 1] == 0x29 then -- ')': Pattern is '()', capture position
				captures[ncapt] = sp
				ret = match( sp, pp + 2 )
			else
				-- Start capture group
				captures[ncapt] = { sp, '' }
				ret = match( sp, pp + 1 )
			end
			if ret then
				return ret
			else
				-- Failed, rollback
				ncapt = ncapt - 1
				return nil
			end

		elseif c == 0x29 then -- ')': ends capture group, pop current capture index from stack
			for n = ncapt, 1, -1 do
				if type( captures[n] ) == 'table' and captures[n][2] == '' then
					captures[n][2] = sp - 1
					local ret = match( sp, pp + 1 )
					if ret then
						return ret
					else
						-- Failed, rollback
						captures[n][2] = ''
						return nil
					end
				end
			end
			error( 'Unmatched close-paren at pattern character ' .. pp, 3 )

		elseif c == 0x5b then -- '[': starts character set
			return match_charset( sp, parse_charset( pp ) )

		elseif c == 0x5d then -- ']'
			error( 'Unmatched close-bracket at pattern character ' .. pp, 3 )

		elseif c == 0x25 then -- '%'
			c = pattern.codepoints[pp + 1]
			if not c then -- percent at the end of the pattern
				-- Tested first: the numeric comparisons in the branches below
				-- would otherwise raise "attempt to compare nil with number".
				error( 'malformed pattern (ends with \'%\')', 3 )
			elseif charsets[c] then -- A character set like '%a'
				return match_charset( sp, pp + 2, charsets[c] )
			elseif c == 0x62 then -- '%b': balanced delimiter match
				-- Declared local: these used to leak into the global environment
				local d1 = pattern.codepoints[pp + 2]
				local d2 = pattern.codepoints[pp + 3]
				if not d1 or not d2 then
					error( 'malformed pattern (missing arguments to \'%b\')', 3 )
				end
				if cps.codepoints[sp] ~= d1 then
					return nil
				end
				sp = sp + 1
				local ct = 1
				while true do
					c = cps.codepoints[sp]
					sp = sp + 1
					if not c then
						return nil
					elseif c == d2 then
						if ct == 1 then
							return match( sp, pp + 4 )
						end
						ct = ct - 1
					elseif c == d1 then
						ct = ct + 1
					end
				end
			elseif c >= 0x30 and c <= 0x39 then -- '%0' to '%9': backreference
				-- The message names the capture number, not the digit's codepoint
				local m, l = getcapt( c - 0x30, 'invalid capture index %' .. ( c - 0x30 ) .. ' at pattern character ' .. pp, 3 )
				local ep = math.min( cps.len + 1, sp + l )
				if sub( s, cps, sp, ep - 1 ) == m then
					return match( ep, pp + 2 )
				else
					return nil
				end
			else -- something else, treat as a literal
				return match_charset( sp, pp + 2, { [c] = 1 } )
			end

		elseif c == 0x2e then -- '.': match anything
			-- Hold the entry in a local: the cache is weak, so re-reading it
			-- is not guaranteed to find the entry again.
			local any = charset_cache['.']
			if not any then
				local t = {}
				setmetatable( t, { __index = function ( t, k ) return k end } )
				any = { 1, t }
				charset_cache['.'] = any
			end
			return match_charset( sp, pp + 1, any[2] )

		elseif c == nil then -- end of pattern
			return sp

		elseif c == 0x24 and pattern.len == pp then -- '$': assert end of string
			return ( sp == cps.len + 1 ) and sp or nil

		else
			-- Any other character matches itself
			return match_charset( sp, pp + 1, { [c] = 1 } )
		end
	end

	-- Parse a bracketed character set (e.g. [a-z])
	-- Returns the position after the set and a table holding the matching characters
	parse_charset = function ( pp )
		-- Find the closing ']' (skipping a '%]' escape) so the set's source
		-- text can serve as the cache key. The search must start at this
		-- set's '['. It used to start from an undefined global, i.e. at
		-- byte 1, so every set after the pattern's first ']' got a wrong
		-- (often empty) key and could collide with unrelated sets.
		local _, ep
		local epp = pattern.bytepos[pp]
		repeat
			_, ep = S.find( rawpat, ']', epp, true )
			if not ep then
				error( 'Missing close-bracket for character set beginning at pattern character ' .. pp, 3 )
			end
			epp = ep + 1
		until S.byte( rawpat, ep - 1 ) ~= 0x25 or S.byte( rawpat, ep - 2 ) == 0x25
		local key = S.sub( rawpat, pattern.bytepos[pp], ep )
		local cached = charset_cache[key]
		if cached then
			return pp + cached[1], cached[2]
		end

		local p0 = pp
		local cs = {}
		local csrefs = { cs }
		local invert = false
		pp = pp + 1
		if pattern.codepoints[pp] == 0x5e then -- '^'
			invert = true
			pp = pp + 1
		end
		while true do
			local c = pattern.codepoints[pp]
			if c == 0x25 then -- '%'
				c = pattern.codepoints[pp + 1]
				if charsets[c] then
					csrefs[#csrefs + 1] = charsets[c]
				else
					cs[c] = 1
				end
				pp = pp + 2
			elseif c == 0x5d then -- closing ']'
				-- Tested before the range check, so that a set followed by a
				-- '-' quantifier (e.g. '[a]-b') ends here instead of being
				-- read as a range starting at ']'.
				pp = pp + 1
				break
			elseif pattern.codepoints[pp + 1] == 0x2d and pattern.codepoints[pp + 2] and pattern.codepoints[pp + 2] ~= 0x5d then -- '-' followed by another char (not ']'), it's a range
				for i = c, pattern.codepoints[pp + 2] do
					cs[i] = 1
				end
				pp = pp + 3
			elseif not c then -- Should never get here, but Just In Case...
				error( 'Missing close-bracket', 3 )
			else
				cs[c] = 1
				pp = pp + 1
			end
		end

		local ret
		if not csrefs[2] then
			if not invert then
				-- If there's only the one charset table, we can use it directly
				ret = cs
			else
				-- Simple invert
				ret = {}
				setmetatable( ret, { __index = function ( t, k ) return k and not cs[k] end } )
			end
		else
			-- Ok, we have to iterate over multiple charset tables
			ret = {}
			setmetatable( ret, { __index = function ( t, k )
				if not k then
					return nil
				end
				for i = 1, #csrefs do
					if csrefs[i][k] then
						return not invert
					end
				end
				return invert
			end } )
		end

		charset_cache[key] = { pp - p0, ret }
		return pp, ret
	end

	-- Match a character set table with optional quantifier, followed by
	-- the rest of the pattern.
	-- Returns same as 'match' above.
	match_charset = function ( sp, pp, charset )
		local q = pattern.codepoints[pp]
		if q == 0x2a then -- '*', 0 or more matches
			pp = pp + 1
			local i = 0
			while charset[cps.codepoints[sp + i]] do
				i = i + 1
			end
			while i >= 0 do
				local ret = match( sp + i, pp )
				if ret then
					return ret
				end
				i = i - 1
			end
			return nil

		elseif q == 0x2b then -- '+', 1 or more matches
			pp = pp + 1
			local i = 0
			while charset[cps.codepoints[sp + i]] do
				i = i + 1
			end
			while i > 0 do
				local ret = match( sp + i, pp )
				if ret then
					return ret
				end
				i = i - 1
			end
			return nil

		elseif q == 0x2d then -- '-', 0 or more matches non-greedy
			pp = pp + 1
			while true do
				local ret = match( sp, pp )
				if ret then
					return ret
				end
				if not charset[cps.codepoints[sp]] then
					return nil
				end
				sp = sp + 1
			end

		elseif q == 0x3f then -- '?', 0 or 1 match
			pp = pp + 1
			if charset[cps.codepoints[sp]] then
				local ret = match( sp + 1, pp )
				if ret then
					return ret
				end
			end
			return match( sp, pp )

		else -- no suffix, must match 1
			if charset[cps.codepoints[sp]] then
				return match( sp + 1, pp )
			else
				return nil
			end
		end
	end

	-- Here is the actual match loop. It just calls 'match' on successive
	-- starting positions (or not, if the pattern is anchored) until it finds a
	-- match.
	local sp = init
	local pp = 1
	if pattern.codepoints[1] == 0x5e then -- '^': Pattern is anchored
		anchor = true
		pp = 2
	end

	repeat
		ncapt, captures = 0, {}
		local ep = match( sp, pp )
		if ep then
			for i = 1, ncapt do
				-- captparen is indexed by capture number. Indexing it with the
				-- pattern position pp could yield nil and crash the concatenation
				-- even on a successful match.
				captures[i] = getcapt( i, 'Unclosed capture beginning at pattern character ' .. captparen[i], 2 )
			end
			return sp, ep - 1, unpack( captures, 1, ncapt )
		end
		sp = sp + 1
	until anchor or sp > cps.len + 1

	return nil
end
|
||
|
||
-- Private function to decide if a pattern looks simple enough to use
-- Lua's built-in string library. The following make a pattern not simple:
-- * If it contains any bytes over 0x7f. We could skip these if they're not
--   inside brackets and aren't followed by quantifiers and aren't part of a
--   '%b', but that's too complicated to check.
-- * If it contains "%a" or any of the other %-prefixed character sets except
--   %z or %Z.
-- * If it contains a '.' not followed by '*', '+', or '-', including a '.'
--   at the very end of the pattern. A bare '.' or '.?' would try to match a
--   partial UTF-8 character, but the others will happily enough match a
--   whole character thinking it's 2 or 4.
-- * If it contains position-captures.
--
-- The checks are conservative: an escaped '%.' also makes a pattern
-- non-simple, which only costs speed.
--
-- @param string pattern
-- @return boolean
local function patternIsSimple( pattern )
	return not (
		S.find( pattern, '[\128-\255]' ) or
		S.find( pattern, '%%[acdlpsuwxACDLPSUWX]' ) or
		S.find( pattern, '%.[^*+-]' ) or
		-- '%.[^*+-]' needs a following character, so a trailing '.' (which
		-- would match a single byte of a multibyte character) needs its own test
		S.find( pattern, '%.$' ) or
		S.find( pattern, '()', 1, true )
	)
end
|
||
|
||
-- Find a pattern in a string
--
-- This works just like string.find, with the following changes:
-- * Everything works on UTF-8 characters rather than bytes
-- * Character classes are redefined in terms of Unicode properties:
--   * %a - Letter
--   * %c - Control
--   * %d - Decimal Number
--   * %l - Lower case letter
--   * %p - Punctuation
--   * %s - Separator, plus HT, LF, FF, CR, and VT
--   * %u - Upper case letter
--   * %w - Letter or Decimal Number
--   * %x - [0-9A-Fa-f], presumably plus the fullwidth hex digits
--     (NOTE(review): the tables live in ustring/charsets; confirm)
--
-- @see string.find
-- @param s string
-- @param pattern string Pattern
-- @param init int Starting character [default 1]; negative counts from the
--   end; clipped to 1..len + 1 as string.find does
-- @param plain boolean Literal match, no pattern matching
-- @return int starting index of the match
-- @return int ending index of the match
-- @return string|int* captures
function ustring.find( s, pattern, init, plain )
	checkString( 'find', s )
	checkPattern( 'find', pattern )
	checkType( 'find', 3, init, 'number', true )
	checkType( 'find', 4, plain, 'boolean', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'find' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'find' (string is not UTF-8)", 2 )
	end

	-- Normalize init (a character index) before either code path uses it.
	-- The fast path used to index bytepos with the raw argument, so a
	-- negative or too-large init silently restarted the search at byte 1.
	init = init or 1
	if init < 0 then
		init = cps.len + init + 1
	end
	init = math.max( 1, math.min( init, cps.len + 1 ) )

	if plain or patternIsSimple( pattern ) then
		local m = { S.find( s, pattern, cps.bytepos[init], plain ) }
		if m[1] then
			local first = cpoffset( cps, m[1] )
			if m[2] < m[1] then
				-- Empty match: it ends just before the character it starts
				-- at. cpoffset( cps, 0 ) would give 1, not 0.
				m[2] = first - 1
			else
				m[2] = cpoffset( cps, m[2] )
			end
			m[1] = first
		end
		return unpack( m )
	end

	return find( s, cps, pattern, pat, init )
end
|
||
|
||
-- Match a string against a pattern
--
-- @see ustring.find
-- @see string.match
-- @param s string
-- @param pattern string
-- @param init int Starting character [default 1]; negative counts from the
--   end; clipped to 1..len + 1 as string.match does
-- @return string|int* captures, or the whole match if there are none
function ustring.match( s, pattern, init )
	checkString( 'match', s )
	checkPattern( 'match', pattern )
	checkType( 'match', 3, init, 'number', true )
	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'match' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'match' (string is not UTF-8)", 2 )
	end

	-- Normalize init (a character index) before either code path uses it.
	-- The fast path used to index bytepos with the raw argument, so a
	-- negative or too-large init silently restarted the search at byte 1.
	init = init or 1
	if init < 0 then
		init = cps.len + init + 1
	end
	init = math.max( 1, math.min( init, cps.len + 1 ) )

	if patternIsSimple( pattern ) then
		return S.match( s, pattern, cps.bytepos[init] )
	end

	local m = { find( s, cps, pattern, pat, init ) }
	if not m[1] then
		return nil
	end
	if m[3] then
		return unpack( m, 3 )
	end
	return sub( s, cps, m[1], m[2] )
end
|
||
|
||
-- Return an iterator function over the matches for a pattern
--
-- As with string.gmatch, an empty match makes the next search start one
-- character later, so the iteration always terminates.
--
-- @see ustring.find
-- @see string.gmatch
-- @param s string
-- @param pattern string
-- @return function Iterator returning the captures (or whole match) of each
--   successive match, then nil
function ustring.gmatch( s, pattern )
	checkString( 'gmatch', s )
	checkPattern( 'gmatch', pattern )
	if patternIsSimple( pattern ) then
		return S.gmatch( s, pattern )
	end

	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'gmatch' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'gmatch' (string is not UTF-8)", 2 )
	end
	local init = 1

	if pat.codepoints[1] == 0x5e then -- '^': Pattern is anchored
		-- An anchored pattern is never matched by this iterator
		return function ()
			return nil
		end
	end

	return function ()
		-- Past the end of the string there is nothing left to try. Without
		-- this guard, a pattern that matches the empty string would keep
		-- succeeding at ever larger positions.
		if init > cps.len + 1 then
			return nil
		end
		local m = { find( s, cps, pattern, pat, init ) }
		if not m[1] then
			return nil
		end
		if m[2] < m[1] then
			-- Empty match: step past its position. Otherwise the next call
			-- finds the same empty match again, forever.
			init = m[1] + 1
		else
			init = m[2] + 1
		end
		if m[3] then
			return unpack( m, 3 )
		end
		return sub( s, cps, m[1], m[2] )
	end
end
|
||
|
||
-- Replace pattern matches in a string
--
-- Works like string.gsub on characters rather than bytes. As in
-- string.gsub, an empty match counts as a replacement. After it, the
-- character at that position is copied unchanged and scanning resumes past
-- it. An empty match is also allowed at position len + 1 (the end of the
-- string).
--
-- @see ustring.find
-- @see string.gsub
-- @param s string
-- @param pattern string
-- @param repl string|function|table
-- @param n int Maximum number of replacements [default unlimited]
-- @return string
-- @return int Number of replacements made
function ustring.gsub( s, pattern, repl, n )
	checkString( 'gsub', s )
	checkPattern( 'gsub', pattern )
	checkType( 'gsub', 4, n, 'number', true )
	if patternIsSimple( pattern ) then
		return S.gsub( s, pattern, repl, n )
	end

	local cps = utf8_explode( s )
	if cps == nil then
		error( "bad argument #1 for 'gsub' (string is not UTF-8)", 2 )
	end
	local pat = utf8_explode( pattern )
	if pat == nil then
		error( "bad argument #2 for 'gsub' (string is not UTF-8)", 2 )
	end
	if n == nil then
		n = 1e100
	end

	if pat.codepoints[1] == 0x5e then -- '^': Pattern is anchored
		-- There can be at most one match, so make that explicit without
		-- raising a smaller caller-supplied limit (e.g. n = 0)
		n = math.min( n, 1 )
	end

	local tp
	if type( repl ) == 'function' then
		tp = 1
	elseif type( repl ) == 'table' then
		tp = 2
	elseif type( repl ) == 'string' then
		tp = 3
	else
		checkType( 'gsub', 3, repl, 'function or table or string' )
	end

	local init = 1
	local ct = 0
	local ret = {}
	-- init never exceeds cps.len + 1 (see the advancement rules below), and
	-- every iteration either advances init, breaks, or reaches the limit n.
	-- The old condition `init < cps.len` skipped the last character and
	-- never matched the empty string.
	while ct < n do
		local m = { find( s, cps, pattern, pat, init ) }
		if not m[1] then
			break
		end
		if init < m[1] then
			ret[#ret + 1] = sub( s, cps, init, m[1] - 1 )
		end
		local mm = sub( s, cps, m[1], m[2] )
		local val
		if tp == 1 then
			if m[3] then
				val = repl( unpack( m, 3 ) )
			else
				val = repl( mm )
			end
		elseif tp == 2 then
			val = repl[m[3] or mm]
		elseif tp == 3 then
			if ct == 0 and #m < 11 then
				local ss = S.gsub( repl, '%%[%%0-' .. ( #m - 2 ) .. ']', 'x' )
				ss = S.match( ss, '%%[0-9]' )
				if ss then
					error( 'invalid capture index ' .. ss .. ' in replacement string', 2 )
				end
			end
			local t = {
				["%0"] = mm,
				["%1"] = m[3],
				["%2"] = m[4],
				["%3"] = m[5],
				["%4"] = m[6],
				["%5"] = m[7],
				["%6"] = m[8],
				["%7"] = m[9],
				["%8"] = m[10],
				["%9"] = m[11],
				["%%"] = "%"
			}
			val = S.gsub( repl, '%%[%%0-9]', t )
		end
		ret[#ret + 1] = val or mm
		ct = ct + 1

		if m[2] >= m[1] then
			-- Non-empty match: continue right after it
			init = m[2] + 1
		elseif m[1] <= cps.len then
			-- Empty match: copy one character verbatim and move past it, as
			-- string.gsub does. Otherwise the loop would find the same empty
			-- match forever.
			ret[#ret + 1] = sub( s, cps, m[1], m[1] )
			init = m[1] + 1
		else
			-- Empty match at the very end of the string: nothing is left
			init = m[1]
			break
		end
	end
	if init <= cps.len then
		ret[#ret + 1] = sub( s, cps, init, cps.len )
	end
	return table.concat( ret ), ct
end
|
||
|
||
---- Unicode Normalization ----
-- These functions load a conversion table when called

-- Private helper: decompose exploded text and put combining marks into
-- canonical order (NFD).
--
-- Each codepoint that has an entry in normal.decomp is replaced by that
-- entry's codepoints. Then any two adjacent codepoints that both have a
-- combining class are swapped while the first class is larger. This is an
-- exchange sort that never reorders equal classes.
--
-- @param cps table Exploded string (from utf8_explode)
-- @return table Decomposed codepoints
-- @return int 1, the first index (so the results can go straight to internalChar)
-- @return int Number of decomposed codepoints
local function internalToNFD( cps )
	local out = {}
	local normal = require 'ustring/normalization-data'

	-- Decompose into out, using the lookup table
	for idx = 1, cps.len do
		local code = cps.codepoints[idx]
		local mapping = normal.decomp[code]
		if not mapping then
			out[#out + 1] = code
		else
			-- NOTE(review): copying starts at index 0; if mappings are
			-- 1-based, mapping[0] is nil and that assignment is a no-op.
			for k = 0, #mapping do
				out[#out + 1] = mapping[k]
			end
		end
	end

	-- Now sort combiners by class
	local count = #out
	local pos = 1
	while pos < count do
		local leftClass = normal.combclass[out[pos]]
		local rightClass = normal.combclass[out[pos + 1]]
		if leftClass and rightClass and leftClass > rightClass then
			out[pos], out[pos + 1] = out[pos + 1], out[pos]
			-- Step back to re-check the previous pair, unless already at the start
			if pos == 1 then
				pos = 2
			else
				pos = pos - 1
			end
		else
			pos = pos + 1
		end
	end

	return out, 1, count
end
|
||
|
||
-- Normalize a string to NFC
--
-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
-- UTF-8.
--
-- @param s string
-- @return string|nil
function ustring.toNFC( s )
	checkString( 'toNFC', s )

	-- Pure ASCII is already NFC
	if S.find( s, '[\128-\255]' ) == nil then
		return s
	end

	local exploded = utf8_explode( s )
	if exploded == nil then
		return nil
	end
	local normal = require 'ustring/normalization-data'

	-- Quick check: unless some codepoint is flagged in normal.check, the
	-- string is returned unchanged.
	local needsWork = false
	for idx = 1, exploded.len do
		if normal.check[exploded.codepoints[idx]] then
			needsWork = true
			break
		end
	end
	if not needsWork then
		return s
	end

	-- Decompose to NFD, then recompose. NFD->NFC never lengthens the
	-- sequence, so the recomposed result is written back into the same
	-- buffer, at or before the position being read.
	local buf, _, count = internalToNFD( exploded )
	local composites = normal.comp[buf[1]] -- compositions available to the current starter
	local starter = 1                       -- index in buf of the current starter
	local written = 1                       -- index in buf of the last output codepoint
	local prevClass = 0                     -- class of the last uncombined combiner since the starter

	for idx = 2, count do
		local code = buf[idx]
		local class = normal.combclass[code]
		local composed = composites and composites[code]

		-- A combiner is blocked by an earlier uncombined combiner of equal
		-- or higher class; another starter only combines when adjacent.
		local blocked
		if class then
			blocked = prevClass >= class
		else
			blocked = prevClass ~= 0
		end

		if composed and not blocked then
			-- Fold into the starter
			buf[starter] = composed
			composites = normal.comp[composed]
		elseif class then
			-- Combiner that stays separate: emit it
			written = written + 1
			buf[written] = code
			prevClass = class
		else
			-- New starter that doesn't combine
			written = written + 1
			buf[written] = code
			composites = normal.comp[code]
			starter = written
			prevClass = 0
		end
	end

	return internalChar( buf, 1, written )
end
|
||
|
||
-- Normalize a string to NFD
--
-- Based on MediaWiki's UtfNormal class. Returns nil if the string is not valid
-- UTF-8.
--
-- @param s string
-- @return string|nil
function ustring.toNFD( s )
	checkString( 'toNFD', s )

	-- Pure ASCII has nothing to decompose, so it is already NFD
	if S.find( s, '[\128-\255]' ) == nil then
		return s
	end

	local exploded = utf8_explode( s )
	if exploded == nil then
		return nil
	end

	local decomposed, first, last = internalToNFD( exploded )
	return internalChar( decomposed, first, last )
end
|
||
|
||
-- Module export: the ustring table built up above
return ustring
|