You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
119 lines
2.4 KiB
Lua
119 lines
2.4 KiB
Lua
local si = require "snax.interface"
|
|
|
|
--- Return the identity of `f`'s `_ENV` upvalue.
-- Functions whose `_ENV` upvalues have the same id share one environment
-- variable, which is how the rest of this module decides whether two
-- closures belong to the same service code.
-- @tparam function f closure to inspect
-- @return the `debug.upvalueid` of the `_ENV` upvalue, or nothing when `f`
--   has no upvalue named `_ENV`
local function envid(f)
	local idx = 0
	repeat
		idx = idx + 1
		local uvname = debug.getupvalue(f, idx)
		if uvname == "_ENV" then
			return debug.upvalueid(f, idx)
		end
	until uvname == nil
end
|
|
|
|
--- Record every upvalue reachable from `f` into `uv`, keyed by name.
-- Each new entry is `{ func = <closure>, index = <upvalue slot>, id = <upvalueid> }`,
-- i.e. enough to later re-point another closure's upvalue at it with
-- `debug.upvaluejoin`. Function-valued upvalues are scanned recursively, but
-- only when their `_ENV` upvalue id equals `env` (same environment).
-- A name is recorded before recursing, so re-encountering it takes the
-- assertion branch instead of recursing again.
-- @param f closure to scan
-- @param uv table name -> entry, filled in place
-- @param env `_ENV` upvalue id that nested functions must share to be followed
-- @raise "ambiguity local value <name>" when two distinct upvalues share a name
local function collect_uv(f , uv, env)
	-- Keep `debug` as the first free name referenced in this body: that keeps
	-- `_ENV` as upvalue #1 of this closure, and code elsewhere in this file may
	-- use (collect_uv, 1) as a handle on this module's _ENV.
	local idx = 1
	local uvname, uvvalue = debug.getupvalue(f, idx)
	while uvname ~= nil do
		local id = debug.upvalueid(f, idx)
		local seen = uv[uvname]
		if seen then
			assert(seen.id == id, string.format("ambiguity local value %s", uvname))
		else
			uv[uvname] = { func = f, index = idx, id = id }
			if type(uvvalue) == "function" and envid(uvvalue) == env then
				collect_uv(uvvalue, uv, env)
			end
		end
		idx = idx + 1
		uvname, uvvalue = debug.getupvalue(f, idx)
	end
end
|
|
|
|
--- Build the name -> upvalue table for the service being patched.
-- Every descriptor in `funcs` whose slot [4] holds a function is scanned with
-- collect_uv, restricted to functions sharing that function's `_ENV`.
-- If no scanned function exposes an `_ENV` upvalue, a fallback entry is added
-- that refers to this module's own `_ENV`, so patch functions with a nil/placeholder
-- `_ENV` still get rebound to a real environment.
-- @param funcs descriptor tables; [2] is the group, [3] the name, [4] the
--   function or nil (as read by find_func / inject)
-- @return table mapping upvalue name -> { func = closure, index = slot[, id = upvalueid] }
-- @raise "ambiguity local value <name>" (from collect_uv)
local function collect_all_uv(funcs)
	local global = {}
	for _, desc in pairs(funcs) do
		local f = desc[4]
		if f then
			collect_uv(f, global, envid(f))
		end
	end
	if not global["_ENV"] then
		-- A closure whose ONLY upvalue is this module's _ENV, so slot 1 is _ENV
		-- by construction. Previously this used collect_uv's slot 1, which only
		-- held _ENV by accident of which name collect_uv referenced first.
		local function env_anchor()
			return _ENV
		end
		global["_ENV"] = { func = env_anchor, index = 1 }
	end
	return global
end
|
|
|
|
--- Wrap patch source text in a loader callback.
-- The returned callback ignores its first two arguments and compiles `source`
-- (text or binary, chunk name "=patch") with `G` as the chunk's environment.
-- @tparam string source patch source code
-- @treturn function callback `(path, name, G)` returning whatever `load`
--   returns: the compiled chunk, or nil plus an error message
local function loader(source)
	local function load_patch(_path, _name, G)
		return load(source, "=patch", "bt", G)
	end
	return load_patch
end
|
|
|
|
--- Find the descriptor in `funcs` with the given group and name.
-- Descriptors are tables whose second and third elements are the group and
-- the name; iteration uses `pairs`, so the first match found is returned.
-- @param funcs collection of descriptor tables
-- @param group group to match (element [2])
-- @param name name to match (element [3])
-- @return the matching descriptor, or nothing when none matches
local function find_func(funcs, group , name)
	for _, entry in pairs(funcs) do
		local _, entry_group, entry_name = table.unpack(entry)
		if entry_group == group and entry_name == name then
			return entry
		end
	end
end
|
|
|
|
-- Shallow copy of this module's _ENV, built once at load time.
-- inject passes it to snax.interface when loading a patch (presumably as the
-- patch's global environment — confirm in snax.interface), and _patch treats
-- any upvalue still holding this exact table as a placeholder to rebind.
local dummy_env = {}
for key, val in pairs(_ENV) do
	dummy_env[key] = val
end
|
|
|
|
--- Bind the free variables of patch function `f` to the running service's upvalues.
-- Walks the upvalues of `f`. Each one whose value is nil or `dummy_env` is
-- re-pointed with `debug.upvaluejoin` at the service upvalue of the same name
-- recorded in `global` (left alone when `global` has no such name).
-- Function-valued upvalues are walked recursively.
-- @param global table name -> { func = closure, index = slot } (see collect_all_uv)
-- @param f the patch function to rebind
-- @param visited optional set of functions already walked; callers normally
--   omit it. It makes self- and mutually-recursive local functions in a patch
--   terminate instead of recursing until stack overflow. It also stops a shared
--   helper from being walked a second time, after its upvalues may already have
--   been joined to old service functions.
local function _patch(global, f, visited)
	visited = visited or {}
	if visited[f] then
		return
	end
	visited[f] = true

	local i = 1
	while true do
		local name, value = debug.getupvalue(f, i)
		if name == nil then
			break
		elseif value == nil or value == dummy_env then
			local old_uv = global[name]
			if old_uv then
				debug.upvaluejoin(f, i, old_uv.func, old_uv.index)
			end
		elseif type(value) == "function" then
			_patch(global, value, visited)
		end
		i = i + 1
	end
end
|
|
|
|
--- Install patch function `f` as the implementation of `group.name`.
-- Rebinds `f`'s free variables with _patch, then stores `f` in slot [4] of
-- the matching descriptor in `funcs`.
-- @raise "Patch mismatch <group>.<name>" when `funcs` has no such descriptor
local function patch_func(funcs, global, group, name, f)
	local target = find_func(funcs, group, name)
	assert(target, string.format("Patch mismatch %s.%s", group, name))
	_patch(global, f)
	target[4] = f
end
|
|
|
|
--- Load a patch and install its functions into the running service.
-- `si("patch", dummy_env, loader(source))` is expected to return descriptor
-- tables `{ _, group, name, f }` (presumably one per function the patch
-- defines — confirm in snax.interface). Every entry with a function replaces
-- the same group/name in `funcs`. If the patch defines `system.hotfix`, it is
-- then called with the extra arguments and its results are returned.
-- @param funcs descriptor tables of the running service (mutated)
-- @param source patch source code
-- @param ... forwarded to the patch's system.hotfix function
-- @raise "Patch mismatch <group>.<name>" before any descriptor is modified
local function inject(funcs, source, ...)
	local patch = si("patch", dummy_env, loader(source))
	local global = collect_all_uv(funcs)

	-- Validate every entry first, so a mismatch leaves `funcs` untouched
	-- instead of half-patched (pcall in the caller would still report failure).
	for _, v in pairs(patch) do
		local _, group, name, f = table.unpack(v)
		if f then
			assert(find_func(funcs, group, name), string.format("Patch mismatch %s.%s", group, name))
		end
	end

	for _, v in pairs(patch) do
		local _, group, name, f = table.unpack(v)
		if f then
			patch_func(funcs, global, group, name, f)
		end
	end

	local hf = find_func(patch, "system", "hotfix")
	if hf and hf[4] then
		return hf[4](...)
	end
end
|
|
|
|
--- Module entry point: apply a hotfix patch to a snax service.
-- Runs inject under pcall, so errors it raises are caught: returns `true`
-- followed by the results of the patch's system.hotfix function (if any),
-- or `false` and the error value.
-- @param funcs descriptor tables of the running service
-- @param source patch source code
-- @param ... forwarded to the patch's system.hotfix function
local function hotfix(funcs, source, ...)
	return pcall(inject, funcs, source, ...)
end

return hotfix
|