You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

532 lines
15 KiB
Lua

local internal = require "http.internal"
local socket = require "skynet.socket"
local crypt = require "skynet.crypt"
local httpd = require "http.httpd"
local skynet = require "skynet"
local sockethelper = require "http.sockethelper"
local socket_error = sockethelper.socket_error
local GLOBAL_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
local MAX_FRAME_SIZE = 256 * 1024 -- max frame is 256K
local M = {}
local ws_pool = {}
local function _close_websocket(ws_obj)
local id = ws_obj.id
assert(ws_pool[id] == ws_obj)
ws_pool[id] = nil
ws_obj.close()
end
local function _isws_closed(id)
return not ws_pool[id]
end
local function write_handshake(self, host, url, header)
local key = crypt.base64encode(crypt.randomkey()..crypt.randomkey())
local request_header = {
["Upgrade"] = "websocket",
["Connection"] = "Upgrade",
["Sec-WebSocket-Version"] = "13",
["Sec-WebSocket-Key"] = key
}
if header then
for k,v in pairs(header) do
assert(request_header[k] == nil, k)
request_header[k] = v
end
end
local recvheader = {}
local code, body = internal.request(self, "GET", host, url, recvheader, request_header)
if code ~= 101 then
error(string.format("websocket handshake error: code[%s] info:%s", code, body))
end
assert(body == "") -- todo: M.read may need handle it
if not recvheader["upgrade"] or recvheader["upgrade"]:lower() ~= "websocket" then
error("websocket handshake upgrade must websocket")
end
if not recvheader["connection"] or recvheader["connection"]:lower() ~= "upgrade" then
error("websocket handshake connection must upgrade")
end
local sw_key = recvheader["sec-websocket-accept"]
if not sw_key then
error("websocket handshake need Sec-WebSocket-Accept")
end
local guid = self.guid
sw_key = crypt.base64decode(sw_key)
if sw_key ~= crypt.sha1(key .. guid) then
error("websocket handshake invalid Sec-WebSocket-Accept")
end
end
local function read_handshake(self, upgrade_ops)
local header, method, url
if upgrade_ops then
header, method, url = upgrade_ops.header, upgrade_ops.method, upgrade_ops.url
else
local tmpline = {}
local header_body = internal.recvheader(self.read, tmpline, "")
if not header_body then
return 413
end
local request = assert(tmpline[1])
local httpver
method, url, httpver = request:match "^(%a+)%s+(.-)%s+HTTP/([%d%.]+)$"
assert(method and url and httpver)
if method ~= "GET" then
return 400, "need GET method"
end
httpver = assert(tonumber(httpver))
if httpver < 1.1 then
return 505 -- HTTP Version not supported
end
header = internal.parseheader(tmpline, 2, {})
end
if not header then
return 400 -- Bad request
end
if not header["upgrade"] or header["upgrade"]:lower() ~= "websocket" then
return 426, "Upgrade Required"
end
if not header["host"] then
return 400, "host Required"
end
if not header["connection"] or not header["connection"]:lower():find("upgrade", 1,true) then
return 400, "Connection must Upgrade"
end
local sw_key = header["sec-websocket-key"]
if not sw_key then
return 400, "Sec-WebSocket-Key Required"
else
local raw_key = crypt.base64decode(sw_key)
if #raw_key ~= 16 then
return 400, "Sec-WebSocket-Key invalid"
end
end
if not header["sec-websocket-version"] or header["sec-websocket-version"] ~= "13" then
return 400, "Sec-WebSocket-Version must 13"
end
local sw_protocol = header["sec-websocket-protocol"]
local sub_pro = ""
if sw_protocol then
local has_chat = false
for sub_protocol in string.gmatch(sw_protocol, "[^%s,]+") do
if sub_protocol == "chat" then
sub_pro = "Sec-WebSocket-Protocol: chat\r\n"
has_chat = true
break
end
end
if not has_chat then
return 400, "Sec-WebSocket-Protocol need include chat"
end
end
-- read 'x-real-ip' header from nginx
self.real_ip = header["x-real-ip"]
-- response handshake
local accept = crypt.base64encode(crypt.sha1(sw_key .. self.guid))
local resp = "HTTP/1.1 101 Switching Protocols\r\n"..
"Upgrade: websocket\r\n"..
"Connection: Upgrade\r\n"..
string.format("Sec-WebSocket-Accept: %s\r\n", accept)..
sub_pro ..
"\r\n"
self.write(resp)
return nil, header, url
end
local function try_handle(self, method, ...)
local handle = self.handle
local f = handle and handle[method]
if f then
f(self.id, ...)
end
end
local op_code = {
["frame"] = 0x00,
["text"] = 0x01,
["binary"] = 0x02,
["close"] = 0x08,
["ping"] = 0x09,
["pong"] = 0x0A,
[0x00] = "frame",
[0x01] = "text",
[0x02] = "binary",
[0x08] = "close",
[0x09] = "ping",
[0x0A] = "pong",
}
local function write_frame(self, op, payload_data, masking_key)
payload_data = payload_data or ""
local payload_len = #payload_data
local op_v = assert(op_code[op])
local v1 = 0x80 | op_v -- fin is 1 with opcode
local s
local mask = masking_key and 0x80 or 0x00
-- mask set to 0
if payload_len < 126 then
s = string.pack("I1I1", v1, mask | payload_len)
elseif payload_len <= 0xffff then
s = string.pack("I1I1>I2", v1, mask | 126, payload_len)
else
s = string.pack("I1I1>I8", v1, mask | 127, payload_len)
end
self.write(s)
-- write masking_key
if masking_key then
s = string.pack(">I4", masking_key)
self.write(s)
payload_data = crypt.xor_str(payload_data, s)
end
if payload_len > 0 then
self.write(payload_data)
end
end
local function read_close(payload_data)
local code, reason
local payload_len = #payload_data
if payload_len > 2 then
local fmt = string.format(">I2c%d", payload_len - 2)
code, reason = string.unpack(fmt, payload_data)
end
return code, reason
end
local function read_frame(self)
local s = self.read(2)
local v1, v2 = string.unpack("I1I1", s)
local fin = (v1 & 0x80) ~= 0
-- unused flag
-- local rsv1 = (v1 & 0x40) ~= 0
-- local rsv2 = (v1 & 0x20) ~= 0
-- local rsv3 = (v1 & 0x10) ~= 0
local op = v1 & 0x0f
local mask = (v2 & 0x80) ~= 0
local payload_len = (v2 & 0x7f)
if payload_len == 126 then
s = self.read(2)
payload_len = string.unpack(">I2", s)
elseif payload_len == 127 then
s = self.read(8)
payload_len = string.unpack(">I8", s)
end
if self.mode == "server" and payload_len > MAX_FRAME_SIZE then
error("payload_len is too large")
end
-- print(string.format("fin:%s, op:%s, mask:%s, payload_len:%s", fin, op_code[op], mask, payload_len))
local masking_key = mask and self.read(4) or false
local payload_data = payload_len>0 and self.read(payload_len) or ""
payload_data = masking_key and crypt.xor_str(payload_data, masking_key) or payload_data
return fin, assert(op_code[op]), payload_data
end
local function resolve_accept(self, options)
try_handle(self, "connect")
local code, err, url = read_handshake(self, options and options.upgrade)
if code then
local ok, s = httpd.write_response(self.write, code, err)
if not ok then
error(s)
end
try_handle(self, "close")
return
end
local header = err
try_handle(self, "handshake", header, url)
local recv_count = 0
local recv_buf = {}
local first_op
while true do
if _isws_closed(self.id) then
try_handle(self, "close")
return
end
local fin, op, payload_data = read_frame(self)
if op == "close" then
local code, reason = read_close(payload_data)
write_frame(self, "close")
try_handle(self, "close", code, reason)
break
elseif op == "ping" then
write_frame(self, "pong", payload_data)
try_handle(self, "ping")
elseif op == "pong" then
try_handle(self, "pong")
else
if fin and #recv_buf == 0 then
try_handle(self, "message", payload_data, op)
else
recv_buf[#recv_buf+1] = payload_data
recv_count = recv_count + #payload_data
if recv_count > MAX_FRAME_SIZE then
error("payload_len is too large")
end
first_op = first_op or op
if fin then
local s = table.concat(recv_buf)
try_handle(self, "message", s, first_op)
recv_buf = {} -- clear recv_buf
recv_count = 0
first_op = nil
end
end
end
end
end
local SSLCTX_CLIENT = nil
local function _new_client_ws(socket_id, protocol, hostname)
local obj
if protocol == "ws" then
obj = {
close = function ()
socket.close(socket_id)
end,
read = sockethelper.readfunc(socket_id),
write = sockethelper.writefunc(socket_id),
readall = function ()
return socket.readall(socket_id)
end,
}
elseif protocol == "wss" then
local tls = require "http.tlshelper"
SSLCTX_CLIENT = SSLCTX_CLIENT or tls.newctx()
local tls_ctx = tls.newtls("client", SSLCTX_CLIENT, hostname)
local init = tls.init_requestfunc(socket_id, tls_ctx)
init()
obj = {
close = function ()
socket.close(socket_id)
tls.closefunc(tls_ctx)()
end,
read = tls.readfunc(socket_id, tls_ctx),
write = tls.writefunc(socket_id, tls_ctx),
readall = tls.readallfunc(socket_id, tls_ctx),
}
else
error(string.format("invalid websocket protocol:%s", tostring(protocol)))
end
obj.mode = "client"
obj.id = assert(socket_id)
obj.guid = GLOBAL_GUID
ws_pool[socket_id] = obj
return obj
end
local SSLCTX_SERVER = nil
local function _new_server_ws(socket_id, handle, protocol)
local obj
if protocol == "ws" then
obj = {
close = function ()
socket.close(socket_id)
end,
read = sockethelper.readfunc(socket_id),
write = sockethelper.writefunc(socket_id),
}
elseif protocol == "wss" then
local tls = require "http.tlshelper"
if not SSLCTX_SERVER then
SSLCTX_SERVER = tls.newctx()
-- gen cert and key
-- openssl req -x509 -newkey rsa:2048 -days 3650 -nodes -keyout server-key.pem -out server-cert.pem
local certfile = skynet.getenv("certfile") or "./server-cert.pem"
local keyfile = skynet.getenv("keyfile") or "./server-key.pem"
SSLCTX_SERVER:set_cert(certfile, keyfile)
end
local tls_ctx = tls.newtls("server", SSLCTX_SERVER)
local init = tls.init_responsefunc(socket_id, tls_ctx)
init()
obj = {
close = function ()
socket.close(socket_id)
tls.closefunc(tls_ctx)()
end,
read = tls.readfunc(socket_id, tls_ctx),
write = tls.writefunc(socket_id, tls_ctx),
}
else
error(string.format("invalid websocket protocol:%s", tostring(protocol)))
end
obj.mode = "server"
obj.id = assert(socket_id)
obj.handle = handle
obj.guid = GLOBAL_GUID
ws_pool[socket_id] = obj
return obj
end
-- handle interface
-- connect / handshake / message / ping / pong / close / error
function M.accept(socket_id, handle, protocol, addr, options)
if not (options and options.upgrade) then
socket.start(socket_id)
end
protocol = protocol or "ws"
local ws_obj = _new_server_ws(socket_id, handle, protocol)
ws_obj.addr = addr
local on_warning = handle and handle["warning"]
if on_warning then
socket.warning(socket_id, function (id, sz)
on_warning(ws_obj, sz)
end)
end
local ok, err = xpcall(resolve_accept, debug.traceback, ws_obj, options)
local closed = _isws_closed(socket_id)
if not closed then
_close_websocket(ws_obj)
end
if not ok then
if err == socket_error then
if closed then
try_handle(ws_obj, "close")
else
print("socket error err = ", err)
try_handle(ws_obj, "error")
end
else
-- error(err)
return false, err
end
end
return true
end
function M.connect(url, header, timeout)
local protocol, host, uri = string.match(url, "^(wss?)://([^/]+)(.*)$")
if protocol ~= "wss" and protocol ~= "ws" then
error(string.format("invalid protocol: %s", protocol))
end
assert(host)
local host_addr, host_port = string.match(host, "^([^:]+):?(%d*)$")
assert(host_addr and host_port)
if host_port == "" then
host_port = protocol == "ws" and 80 or 443
end
local hostname
if not host_addr:match(".*%d+$") then
hostname = host_addr
end
uri = uri == "" and "/" or uri
local socket_id = sockethelper.connect(host_addr, host_port, timeout)
local ws_obj = _new_client_ws(socket_id, protocol, hostname)
ws_obj.addr = host
write_handshake(ws_obj, host_addr, uri, header)
return socket_id
end
function M.read(id)
local ws_obj = assert(ws_pool[id])
local recv_buf
while true do
local fin, op, payload_data = read_frame(ws_obj)
if op == "close" then
_close_websocket(ws_obj)
return false, payload_data
elseif op == "ping" then
write_frame(ws_obj, "pong", payload_data)
elseif op ~= "pong" then -- op is frame, text binary
if fin and not recv_buf then
return payload_data
else
recv_buf = recv_buf or {}
recv_buf[#recv_buf+1] = payload_data
if fin then
local s = table.concat(recv_buf)
return s
end
end
end
end
end
function M.write(id, data, fmt, masking_key)
local ws_obj = assert(ws_pool[id])
fmt = fmt or "text"
assert(fmt == "text" or fmt == "binary")
write_frame(ws_obj, fmt, data, masking_key)
end
function M.ping(id)
local ws_obj = assert(ws_pool[id])
write_frame(ws_obj, "ping")
end
function M.addrinfo(id)
local ws_obj = assert(ws_pool[id])
return ws_obj.addr
end
function M.real_ip(id)
local ws_obj = assert(ws_pool[id])
return ws_obj.real_ip
end
function M.close(id, code ,reason)
local ws_obj = ws_pool[id]
if not ws_obj then
return
end
local ok, err = xpcall(function ()
reason = reason or ""
local payload_data
if code then
local fmt =string.format(">I2c%d", #reason)
payload_data = string.pack(fmt, code, reason)
end
write_frame(ws_obj, "close", payload_data)
end, debug.traceback)
_close_websocket(ws_obj)
if not ok then
skynet.error(err)
end
end
return M