You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

182 lines
3.8 KiB
Lua

local skynet = require "skynet"
local netpack = require "skynet.netpack"
local socketdriver = require "skynet.socketdriver"
local gateserver = {}

-- Listening socket id returned by socketdriver.listen; reset to nil when
-- the listen socket's close event arrives.
local socket
-- netpack message queue, refreshed by the "socket" protocol dispatcher.
local queue
-- Connection limit, configured by the "open" command.
local maxclient
-- Client fds currently counted against maxclient.
local client_number = 0
-- When truthy, socketdriver.nodelay is applied to each accepted fd.
local nodelay = false

-- Command table for the "lua" protocol. When the table is collected its
-- __gc calls netpack.clear(queue) (presumably releasing buffered packets).
local CMD = setmetatable({}, {
    __gc = function()
        netpack.clear(queue)
    end,
})

-- Per-fd connection state:
--   true  : connected
--   nil   : closed
--   false : read side closed
local connection = {}
--- Call socketdriver.start on a client fd the gate still marks as connected.
-- Does nothing when the fd's connection state is nil (closed) or false
-- (read side closed).
-- @tparam number fd client socket id
function gateserver.openclient(fd)
    if not connection[fd] then
        return
    end
    socketdriver.start(fd)
end
--- Stop tracking a client fd and close it with socketdriver.close.
-- Applies to any tracked fd, including one whose read side already closed
-- (state false); fds with no entry (nil) are ignored.
-- @tparam number fd client socket id
function gateserver.closeclient(fd)
    if connection[fd] == nil then
        return
    end
    connection[fd] = nil
    socketdriver.close(fd)
end
--- Turn the current service into a gate server driven by `handler`.
-- Registers the "socket" protocol and a "lua" command dispatcher. The
-- dispatcher is installed from skynet.start, or immediately when
-- `handler.embed` is truthy.
--
-- Required callbacks:
--   handler.message(fd, msg, sz)      -- data for a connected fd
--   handler.connect(fd, msg)          -- a client fd was accepted; msg is the
--                                     -- "open" event payload (presumably the
--                                     -- peer address -- confirm)
-- Optional callbacks:
--   handler.open(source, conf)        -- after listening starts; its results
--                                     -- are returned to the "open" caller
--   handler.disconnect(fd)            -- a client fd closed
--   handler.error(fd, msg)            -- error event on a client fd
--   handler.warning(fd, size)         -- warning event on a client fd
--   handler.command(cmd, source, ...) -- any "lua" command not in CMD
-- @tparam table handler callback table
function gateserver.start(handler)
    assert(handler.message)
    assert(handler.connect)

    -- Handshake state for a CMD.open waiting on the listen socket's "init"
    -- event: { co = waiting coroutine, fd = listen id, addr, port }.
    -- nil whenever no CMD.open is waiting.
    local listen_context

    --- "open" command: listen, wait for the "init" event, start accepting.
    -- conf.address defaults to "0.0.0.0"; conf.port is required;
    -- conf.maxclient defaults to 1024; conf.nodelay, when truthy, applies
    -- socketdriver.nodelay to accepted fds. conf.address / conf.port are
    -- overwritten with the values delivered by the "init" event.
    function CMD.open(source, conf)
        assert(not socket)
        local address = conf.address or "0.0.0.0"
        local port = assert(conf.port)
        maxclient = conf.maxclient or 1024
        nodelay = conf.nodelay
        skynet.error(string.format("Listen on %s:%d", address, port))
        socket = socketdriver.listen(address, port)
        -- Fresh context per call, so "open" works again after "close"
        -- (a shared table would be nil after the first handshake).
        local ctx = { co = coroutine.running(), fd = socket }
        listen_context = ctx
        skynet.wait(ctx.co)
        listen_context = nil
        conf.address = ctx.addr
        conf.port = ctx.port
        socketdriver.start(socket)
        if handler.open then
            return handler.open(source, conf)
        end
    end

    --- "close" command: close the listen socket. `socket` itself is cleared
    -- later, when the close event for it reaches MSG.close.
    function CMD.close()
        assert(socket)
        socketdriver.close(socket)
    end

    -- Socket event handlers, keyed by the event type the protocol's
    -- dispatch receives from netpack.filter.
    local MSG = {}

    -- Forward data for a connected fd to the handler; otherwise log it
    -- (rendered with netpack.tostring) and drop it.
    local function dispatch_msg(fd, msg, sz)
        if connection[fd] then
            handler.message(fd, msg, sz)
        else
            skynet.error(string.format("Drop message from fd (%d) : %s", fd, netpack.tostring(msg,sz)))
        end
    end

    MSG.data = dispatch_msg

    -- Drain the messages buffered in `queue`.
    local function dispatch_queue()
        local fd, msg, sz = netpack.pop(queue)
        if fd then
            -- Fork another drainer first, so dispatching continues even if
            -- handler.message blocks. If handler.message never blocks, this
            -- call empties the queue and the forked drainer finds nothing
            -- and exits.
            skynet.fork(dispatch_queue)
            dispatch_msg(fd, msg, sz)

            for fd, msg, sz in netpack.pop, queue do
                dispatch_msg(fd, msg, sz)
            end
        end
    end

    MSG.more = dispatch_queue

    --- A client fd was accepted.
    function MSG.open(fd, msg)
        -- Counted before the limit check: MSG.close decrements for every
        -- non-listen fd, presumably including rejected ones once their
        -- shutdown completes.
        client_number = client_number + 1
        -- `>` admits exactly maxclient concurrent clients (`>=` admitted
        -- only maxclient - 1, and none at all when maxclient == 1).
        if client_number > maxclient then
            socketdriver.shutdown(fd)
            return
        end
        if nodelay then
            socketdriver.nodelay(fd)
        end
        connection[fd] = true
        handler.connect(fd, msg)
    end

    --- A socket closed: either a client fd or the listen socket itself.
    function MSG.close(fd)
        if fd ~= socket then
            client_number = client_number - 1
            if connection[fd] then
                connection[fd] = false -- close read
            end
            if handler.disconnect then
                handler.disconnect(fd)
            end
        else
            socket = nil
        end
    end

    --- Error event: logged for the listen socket; a client fd is shut down
    -- and reported to handler.error.
    function MSG.error(fd, msg)
        if fd == socket then
            skynet.error("gateserver accept error:",msg)
        else
            socketdriver.shutdown(fd)
            if handler.error then
                handler.error(fd, msg)
            end
        end
    end

    --- Warning event, forwarded to handler.warning when present.
    function MSG.warning(fd, size)
        if handler.warning then
            handler.warning(fd, size)
        end
    end

    --- "init" event: only acted on while a CMD.open is waiting; hands the
    -- reported addr/port to it and wakes it.
    function MSG.init(id, addr, port)
        if listen_context then
            local co = listen_context.co
            if co then
                assert(id == listen_context.fd)
                listen_context.addr = addr
                listen_context.port = port
                skynet.wakeup(co)
                listen_context.co = nil
            end
        end
    end

    skynet.register_protocol {
        name = "socket",
        id = skynet.PTYPE_SOCKET, -- PTYPE_SOCKET = 6
        unpack = function ( msg, sz )
            return netpack.filter( queue, msg, sz)
        end,
        dispatch = function (_, _, q, type, ...)
            queue = q
            if type then
                MSG[type](...)
            end
        end
    }

    -- Install the "lua" dispatcher: built-in CMD entries first, everything
    -- else goes to handler.command.
    local function init()
        skynet.dispatch("lua", function (_, address, cmd, ...)
            local f = CMD[cmd]
            if f then
                skynet.ret(skynet.pack(f(address, ...)))
            else
                skynet.ret(skynet.pack(handler.command(cmd, address, ...)))
            end
        end)
    end

    if handler.embed then
        init()
    else
        skynet.start(init)
    end
end
-- Module table exposing openclient, closeclient and start.
return gateserver