diff --git a/README.MD b/README.MD index 30b6d47..42396d0 100644 --- a/README.MD +++ b/README.MD @@ -31,10 +31,6 @@ For demostration purposes DDoS-protection mode was enabled by default. #### Installation Before installing the tool, ensure that HaProxy is built with Lua support. -- Install lua dependencies: -```bash -luarocks install --only-deps dependencies-0-0.rockspec -``` - Copy [scripts](src/scripts) to a folder accessible for HaProxy - Copy haproxy config and make sure that `lua-load` directive contains absolute path to [register.lua](src/scripts/register.lua) - Copy [libs](src/libs) to a path where Lua looks for modules. @@ -45,7 +41,7 @@ luarocks install --only-deps dependencies-0-0.rockspec The system comes with CLI. It can be used to manage global and per-domain protection. Ensure that stat socket is configured in HaProxy for CLI support. ```bash -Usage: ./ddos-cli [options] +Usage: ddos-cli [options] Command line interface to manage per-domain and global DDoS protection. @@ -53,16 +49,17 @@ optional arguments: -h, --help Show this help message and exit. Commands: - Global management: - ./ddos-cli global status Show status of global server ddos mode. - ./ddos-cli global enable Enable global ddos mode. - ./ddos-cli global disable Disable global ddos mode. + Global management: + ddos-cli global status Show status of global server ddos mode. + ddos-cli global enable Enable global ddos mode. + ddos-cli global disable Disable global ddos mode. + + Domain management: + ddos-cli domain list List all domains with ddos mode on. + ddos-cli domain status Get ddos mode status for a domain. + ddos-cli domain enable Enable ddos mode for a domain. + ddos-cli domain disable Disable ddos mode for a domain. - Domain management: - ./ddos-cli domain list List all domains with ddos mode on. - ./ddos-cli domain status Get ddos mode status for a domain. - ./ddos-cli domain add Enable ddos mode for a domain. - ./ddos-cli domain del Disable ddos mode for a domain. ``` diff --git a/dependencies-0-0.rockspec b/dependencies-0-0.rockspec deleted file mode 100644 index fd9269f..0000000 --- a/dependencies-0-0.rockspec +++ /dev/null @@ -1,14 +0,0 @@ -rockspec_format = "3.0" -package = "dependencies" -version = "0-0" -source = { - url = "https://github.com/mora9715/haproxy_ddos_protector" -} -dependencies = { - "lua > 5.1", - "md5 >= 1.3", - "net-url >= 0.9", - "luasec >= 1.0", - "luasocket >= 2", - "rapidjson >= 0.7" -} \ No newline at end of file diff --git a/haproxy/Dockerfile b/haproxy/Dockerfile index 582be2f..751ba9b 100644 --- a/haproxy/Dockerfile +++ b/haproxy/Dockerfile @@ -97,25 +97,7 @@ ADD haproxy/docker-entrypoint.sh /usr/local/bin/ RUN ln -s usr/local/bin/docker-entrypoint.sh / # backwards compat # This is terrible mess but we need it for simple testing purposes of our POC -RUN apt-get update && apt-get install libssl-dev make nano wget gcc libreadline-dev unzip git socat cmake g++ -y -RUN wget http://www.lua.org/ftp/lua-5.3.5.tar.gz &&\ - tar -zxf lua-5.3.5.tar.gz &&\ - cd lua-5.3.5 &&\ - make linux test &&\ - make install - -RUN wget "https://luarocks.org/releases/luarocks-3.3.1.tar.gz" &&\ - tar zxpf luarocks-3.3.1.tar.gz &&\ - cd luarocks-3.3.1 &&\ - ./configure --with-lua-include=/usr/local/include --lua-version=5.3 --lua-suffix=5.3 &&\ - make &&\ - make install - -RUN /usr/local/bin/luarocks install luasocket &&\ - /usr/local/bin/luarocks install luasec &&\ - /usr/local/bin/luarocks install net-url &&\ - /usr/local/bin/luarocks install md5 &&\ - /usr/local/bin/luarocks install rapidjson +RUN apt-get update && apt-get install socat dnsutils -y ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] diff --git a/src/cli/ddos-cli b/src/cli/ddos-cli index 2f84b12..1e8dd52 100755 --- a/src/cli/ddos-cli +++ b/src/cli/ddos-cli @@ -43,8 +43,8 @@ Commands: Domain management: $0 domain list List all domains with ddos mode on. $0 domain status Get ddos mode status for a domain. - $0 domain add Enable ddos mode for a domain. - $0 domain del Disable ddos mode for a domain. + $0 domain enable Enable ddos mode for a domain. + $0 domain disable Disable ddos mode for a domain. EOF } @@ -85,7 +85,7 @@ _domain_status() { fi } -_domain_add() { +_domain_enable() { local ddos_domains local domain_acl_id @@ -105,7 +105,7 @@ _domain_add() { echo "DDoS-protection mode was enabled for ${1}" } -_domain_del() { +_domain_disable() { local ddos_domains local domain_acl_id @@ -129,7 +129,7 @@ _global_status() { local global_ddos_acl_id local global_ddos_status - global_ddos_acl_id=$(_h_show_acl | grep ${HAPROXY_GLOBAL_ACL} | cut -d' ' -f1) + global_ddos_acl_id=$(_h_show_acl | grep ${HAPROXY_GLOBAL_ACL} | head -1 | cut -d' ' -f1) global_ddos_status=$(_h_show_acl "${global_ddos_acl_id}" | cut -d' ' -f2) if [[ ${global_ddos_status} -eq 0 ]]; then @@ -140,36 +140,40 @@ _global_status() { } _global_enable() { - local global_ddos_acl_id + declare -a global_ddos_acl_ids local global_ddos_status - global_ddos_acl_id=$(_h_show_acl | grep ${HAPROXY_GLOBAL_ACL} | cut -d' ' -f1) - global_ddos_status=$(_h_show_acl "${global_ddos_acl_id}" | cut -d' ' -f2) + global_ddos_acl_ids=($(_h_show_acl | grep ${HAPROXY_GLOBAL_ACL} | cut -d' ' -f1)) + global_ddos_status=$(_h_show_acl "${global_ddos_acl_ids[0]}" | cut -d' ' -f2) if [[ ${global_ddos_status} -eq 0 ]]; then echo "DDoS-protection mode is already enabled globally" exit 0 fi - _h_add_acl "${global_ddos_acl_id}" 0 &>/dev/null - _h_del_acl "${global_ddos_acl_id}" 1 &>/dev/null + for id in "${global_ddos_acl_ids[@]}"; do + _h_add_acl "${id}" 0 &>/dev/null + _h_del_acl "${id}" 1 &>/dev/null + done echo "DDoS-protection mode was enabled globally" } _global_disable() { - local global_ddos_acl_id + declare -a global_ddos_acl_ids local global_ddos_status - global_ddos_acl_id=$(_h_show_acl | grep ${HAPROXY_GLOBAL_ACL} | cut -d' ' -f1) - global_ddos_status=$(_h_show_acl "${global_ddos_acl_id}" | cut -d' ' -f2) + global_ddos_acl_ids=($(_h_show_acl | grep ${HAPROXY_GLOBAL_ACL} | cut -d' ' -f1)) + global_ddos_status=$(_h_show_acl "${global_ddos_acl_ids[0]}" | cut -d' ' -f2) if [[ ${global_ddos_status} -eq 1 ]]; then echo "DDoS-protection mode is already disabled globally" exit 0 fi - _h_add_acl "${global_ddos_acl_id}" 1 &>/dev/null - _h_del_acl "${global_ddos_acl_id}" 0 &>/dev/null + for id in "${global_ddos_acl_ids[@]}"; do + _h_add_acl "${id}" 1 &>/dev/null + _h_del_acl "${id}" 0 &>/dev/null + done echo "DDoS-protection mode was disabled globally" } @@ -188,12 +192,12 @@ _handle_domain_management() { status) _ensure_domain_passed "${2}" _domain_status "${2}";; - add) + enable) _ensure_domain_passed "${2}" - _domain_add "${2}";; - del) + _domain_enable "${2}";; + disable) _ensure_domain_passed "${2}" - _domain_del "${2}";; + _domain_disable "${2}";; *) _help; exit 1;; esac } diff --git a/src/libs/http.lua b/src/libs/http.lua new file mode 100644 index 0000000..13ec3cb --- /dev/null +++ b/src/libs/http.lua @@ -0,0 +1,779 @@ +-- +-- HTTP 1.1 library for HAProxy Lua modules +-- +-- The library is loosely modeled after Python's Requests Library +-- using the same field names and very similar calling conventions for +-- "HTTP verb" methods (where we use Lua specific named parameter support) +-- +-- In addition to client side, the library also supports server side request +-- parsing, where we utilize HAProxy Lua API for all heavy lifting. +-- +-- +-- Copyright (c) 2017-2020. Adis Nezirović +-- Copyright (c) 2017-2020. HAProxy Technologies, LLC. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +local _author = "Adis Nezirovic " +local _copyright = "Copyright 2017-2020. HAProxy Technologies, LLC." +local _version = "1.0.0" + +local json = require "json" + +-- Utility functions + +-- HTTP headers fetch helper +-- +-- Returns a header value(s) according to strategy (fold by default): +-- - single/string value for "fold", "first" and "last" strategies +-- - table for "all" strategy (for single value, a table with single element) +-- +-- @param hdrs table Headers table as received by http.get and friends +-- @param name string Header name +-- @param strategy string "multiple header values" handling strategy +-- @return header value (string or table) or nil +local function get_header(hdrs, name, strategy) + if hdrs == nil or name == nil then return nil end + + local v = hdrs[name:lower()] + if type(v) ~= "table" and strategy ~= "all" then return v end + + if strategy == nil or strategy == "fold" then + return table.concat(v, ",") + elseif strategy == "first" then + return v[1] + elseif strategy == "last" then + return v[#v] + elseif strategy == "all" then + if type(v) ~= "table" then + return {v} + else + return v + end + end +end + + + +-- HTTP headers iterator helper +-- +-- Returns key/value pairs for all header, making sure that returned values +-- are always of string type (if necessary, it folds multiple headers with +-- the same name) +-- +-- @param hdrs table Headers table as received by http.get and friends +-- @return header name/value iterator (suitable for use in "for" loops) +local function get_headers_folded(hdrs) + if hdrs == nil then + return function() end + end + + local function iter(t, k) + local v + k, v = next(t, k) + + if v ~= nil then + if type(v) ~= "table" then + return k, v + else + return k, table.concat(v, ",") + end + end + end + + return iter, hdrs, nil +end + +-- HTTP headers iterator +-- +-- Returns key/value pairs for all headers, for multiple headers with same name +-- it will return every name/value pair +-- (i.e. you can safely use it to process responses with 'Set-Cookie' header) +-- +-- @param hdrs table Headers table as received by http.get and friends +-- @return header name/value iterator (suitable for use in "for" loops) +local function get_headers_flattened(hdrs) + if hdrs == nil then + return function() end + end + + local k -- top level key (string) + local k_sub = 0 -- sub table key (integer), 0 if item not a table, + -- nil after last sub table iteration + local v_sub -- sub table + + return function () + local v + if k_sub == 0 then + k, v = next(hdrs, k) + if k == nil then return end + else + k_sub, v = next(v_sub, k_sub) + + if k_sub == nil then + k_sub = 0 + k, v = next(hdrs, k) + end + end + + if k == nil then return end + + if type(v) ~= "table" then + return k, v + else + v_sub = v + k_sub = k_sub + 1 + return k, v[k_sub] + end + end +end + + +--- Parse key/value pairs from a string +-- +-- @param s Lua string with (multiple) key/value pairs (separated by 'sep') +-- +-- @return Table with parsed keys and values or nil +local function parse_kv(s, sep) + if s == nil then return nil end + idx = 1 + result = {} + + while idx < s:len() do + i, j = s:find(sep, idx) + + if i == nil then + k, v = string.match(s:sub(idx), "^(.-)=(.*)$") + if k then result[k] = v end + break + end + + k, v = string.match(s:sub(idx, i-1), "^(.-)=(.*)$") + if k then result[k] = v end + idx = j + 1 + end + + if next(result) == nil then + return nil + else + return result + end +end + + +--- Namespace object which hosts HTTP verb methods and request/response classes +local M = {} + + +--- HTTP response class +M.response = {} +M.response.__index = M.response + +local _reason = { + [200] = "OK", + [201] = "Created", + [204] = "No Content", + [301] = "Moved Permanently", + [302] = "Found", + [400] = "Bad Request", + [403] = "Forbidden", + [404] = "Not Found", + [405] = "Method Not Allowed", + [408] = "Request Timeout", + [413] = "Payload Too Large", + [429] = "Too many requests", + [500] = "Internal Server Error", + [501] = "Not Implemented", + [502] = "Bad Gateway", + [503] = "Service Unavailable", + [504] = "Gateway Timeout" +} + +--- Creates HTTP response from scratch +-- +-- @param status_code HTTP status code +-- @param reason HTTP status code text (e.g. "OK" for 200 response) +-- @param headers HTTP response headers +-- @param request The HTTP request which triggered the response +-- @param encoding Default encoding for response or conversions +-- +-- @return response object +function M.response.create(t) + local self = setmetatable({}, M.response) + + if not t then + t = {} + end + + self.status_code = t.status_code or nil + self.reason = t.reason or _reason[self.status_code] or "" + self.headers = t.headers or {} + self.content = t.content or "" + self.request = t.request or nil + self.encoding = t.encoding or "utf-8" + + return self +end + +function M.response.send(self, applet) + applet:set_status(tonumber(self.status_code), self.reason) + + for k, v in pairs(self.headers) do + if type(v) == "table" then + for _, hdr_val in pairs(v) do + applet:add_header(k, hdr_val) + end + else + applet:add_header(k, v) + end + end + + if not self.headers["content-type"] then + if type(self.content) == "table" then + applet:add_header("content-type", "application/json; charset=" .. + self.encoding) + if next(self.content) == nil then + -- Return empty JSON object for empty Lua tables + -- (that makes more sense then returning []) + self.content = "{}" + else + self.content = json.encode(self.content) + end + else + applet:add_header("content-type", "text/plain; charset=" .. + self.encoding) + end + end + + if not self.headers["content-length"] then + applet:add_header("content-length", #tostring(self.content)) + end + + applet:start_response() + applet:send(tostring(self.content)) +end + +--- Convert response content to JSON +-- +-- @return Lua table (decoded json) +function M.response.json(self) + return json.decode(self.content) +end + +-- Response headers getter +-- +-- Returns a header value(s) according to strategy (fold by default): +-- - single/string value for "fold", "first" and "last" strategies +-- - table for "all" strategy (for single value, a table with single element) +-- +-- @param name string Header name +-- @param strategy string "multiple header values" handling strategy +-- @return header value (string or table) or nil +function M.response.get_header(self, name, strategy) + return get_header(self.headers, name, strategy) +end + +-- Response headers iterator +-- +-- Yields key/value pairs for all headers, making sure that returned values +-- are always of string type +-- +-- @param folded boolean Specifies whether to fold headers with same name +-- @return header name/value iterator (suitable for use in "for" loops) +function M.response.get_headers(self, folded) + if folded == true then + return get_headers_folded(self.headers) + else + return get_headers_flattened(self.headers) + end +end + + +--- HTTP request class (client or server side, depending on the constructor) +M.request = {} +M.request.__index = M.request + +--- HTTP request constructor +-- +-- Parses client HTTP request (as forwarded by HAProxy) +-- +-- @param applet HAProxy AppletHTTP Lua object +-- +-- @return Request object +function M.request.parse(applet) + local self = setmetatable({}, M.request) + self.method = applet.method + + if (applet.method == "POST" or applet.method == "PUT") and + applet.length > 0 then + self.data = applet:receive() + if self.data == "" then self.data = nil end + end + + self.headers = {} + for k, v in pairs(applet.headers) do + if (v[1]) then -- (non folded header with multiple values) + self.headers[k] = {} + for _, val in pairs(v) do + table.insert(self.headers[k], val) + end + else + self.headers[k] = v[0] + end + end + + if not self.headers["host"] then + return nil, "Bad request, no Host header specified" + end + + self.cookies = parse_kv(self.headers["cookie"], "; ") + + -- TODO: Patch ApletHTTP and add schema of request + local schema = applet.schema or "http" + local url = {schema, "://", self.headers["host"], applet.path} + + self.params = {} + if applet.qs:len() > 0 then + for _, arg in ipairs(core.tokenize(applet.qs, "&", true)) do + kv = core.tokenize(arg, "=", true) + self.params[kv[1]] = kv[2] + end + url[#url+1] = "?" + url[#url+1] = applet.qs + end + + self.url = table.concat(url) + + return self +end + +--- Parse HTTP POST data +-- +-- @return Table with submitted form data +function M.request.parse_multipart(self) + local result ={} + local ct = self.headers['content-type'] + local body = self.data + + if ct:match('^multipart/form[-]data;') then + local boundary = ct:match('^multipart/form[-]data; boundary=(.+)$') + if boundary == nil then + return nil, 'Could not parse boundary from Content-Type' + end + + local i = 1 + local j + local old_i + + while true do + i, j = body:find(boundary, i) + + if i == nil then break end + + if old_i then + local part = body:sub(old_i, i - 1) + local k, fn, t, v = part:match('^\r\n[cC]ontent[-][dD]isposition: form[-]data; name[=]"(.+)"; filename="(.+)"\r\n[cC]ontent[-][tT]ype: (.+)\r\n\r\n(.+)\r\n$') + + if k then + result[k] = { + filename = fn, + content_type = t, + data = v + } + else + k, v = part:match('^\r\n[cC]ontent[-][dD]isposition: form[-]data; name[=]"(.+)"\r\n\r\n(.+)\r\n$') + + if k then + result[k] = v + end + end + + end + + i = j + 1 + old_i = i + end + elseif ct == 'application/x-www-form-urlencoded' then + result = parse_kv(body, '&') + else + return nil, 'Unsupported Content-Type: ' .. ct + end + + if result == nil or not next(result) then + return nil, 'Could not parse form data' + end + + return result +end + +--- Reads (all) chunks from a HTTP response +-- +-- @param socket socket object (with already established tcp connection) +-- @param get_all boolean (true by default), collect all chunks at once +-- or yield every chunk separately. +-- +-- @return Full response payload or nil and an error message +function M.receive_chunked(socket, get_all) + if socket == nil then + return nil, "http.receive_chunked: Socket is nil" + end + local data = {} + + while true do + local chunk, err = socket:receive("*l") + + if chunk == nil then + return nil, "http.receive_chunked(): Receive error (chunk length): " .. tostring(err) + end + + local chunk_len = tonumber(chunk, 16) + if chunk_len == nil then + return nil, "http.receive_chunked(): Could not parse chunk length" + end + + if chunk_len == 0 then + -- TODO: support trailers + break + end + + -- Consume next chunk (including the \r\n) + chunk, err = socket:receive(chunk_len+2) + if chunk == nil then + return nil, "http.receive_chunked(): Receive error (chunk data): " .. tostring(err) + end + + -- Strip the \r\n before collection + local chunk_data = string.sub(chunk, 1, -3) + + if get_all == false then + return chunk_data + end + + table.insert(data, chunk_data) + end + + return table.concat(data) +end + + +-- Request headers getter +-- +-- Returns a header value(s) according to strategy (fold by default): +-- - single/string value for "fold", "first" and "last" strategies +-- - table for "all" strategy (for single value, a table with single element) +-- +-- @param name string Header name +-- @param strategy string "multiple header values" handling strategy +-- @return header value (string or table) or nil +function M.request.get_header(self, name, strategy) + return get_header(self.headers, name, strategy) +end + +-- Request headers iterator +-- +-- Yields key/value pairs for all headers, making sure that returned values +-- are always of string type +-- +-- @param hdrs table Headers table as received by http.get and friends +-- @param folded boolean Specifies whether to fold headers with same name +-- @return header name/value iterator (suitable for use in "for" loops) +function M.request.get_headers(self, folded) + if folded == true then + return get_headers_folded(self.headers) + else + return get_headers_flattened(self.headers) + end +end + +--- Creates HTTP request from scratch +-- +-- @param method HTTP method +-- @param url Valid HTTP url +-- @param headers Lua table with request headers +-- @param data Request content +-- @param params Lua table with request url arguments +-- @param auth (username, password) tuple for HTTP auth +-- +-- @return request object +function M.request.create(t) + local self = setmetatable({}, M.request) + + if t.method then + self.method = t.method:lower() + else + self.method = "get" + end + self.url = t.url or nil + self.headers = t.headers or {} + self.data = t.data or nil + self.params = t.params or {} + self.auth = t.auth or {} + + return self +end + +--- HTTP HEAD request +function M.head(t) + return M.send("HEAD", t) +end + +--- HTTP GET request +function M.get(t) + return M.send("GET", t) +end + +--- HTTP PUT request +function M.put(t) + return M.send("PUT", t) +end + +--- HTTP POST request +function M.post(t) + return M.send("POST", t) +end + +--- HTTP DELETE request +function M.delete(t) + return M.send("DELETE", t) +end + + +--- Send HTTP request +-- +-- @param method HTTP method +-- @param url Valid HTTP url (mandatory) +-- @param headers Lua table with request headers +-- @param data Request content +-- @param params Lua table with request url arguments +-- @param auth (username, password) tuple for HTTP auth +-- @param timeout Optional timeout for socket operations (5s by default) +-- +-- @return Response object or tuple (nil, msg) on errors + +-- Note that the prefered way to call this method is via Lua +-- "keyword arguments" convention, e.g. +-- http.get{uri="http://example.net"} +function M.send(method, t) + if type(t) ~= "table" then + return nil, "http." .. method:lower() .. + ": expecting Request object for named parameters" + end + + if type(t.url) ~= "string" then + return nil, "http." .. method:lower() .. ": 'url' parameter missing" + end + + local socket = core.tcp() + socket:settimeout(t.timeout or 5) + local connect + if t.url:sub(1, 7) ~= "http://" and t.url:sub(1, 8) ~= "https://" then + t.url = "http://" .. t.url + end + local schema, host, req_uri = t.url:match("^(.*)://(.-)(/.*)$") + + if not schema then + -- maybe path (request uri) is missing + schema, host = t.url:match("^(.*)://(.-)$") + if not schema then + return nil, "http." .. method:lower() .. ": Could not parse URL: " .. t.url + end + req_uri = "/" + end + + local addr, port = host:match("(.*):(%d+)") + + if schema == "http" then + connect = socket.connect + if not port then + addr = host + port = 80 + end + elseif schema == "https" then + connect = socket.connect_ssl + if not port then + addr = host + port = 443 + end + else + return nil, "http." .. method:lower() .. ": Invalid URL schema " .. tostring(schema) + end + print("ADDR IS", addr) + local c, err = connect(socket, addr, port) + + if c then + local req = {} + local hdr_tbl = {} + + if t.headers then + for k, v in pairs(t.headers) do + if type(v) == "table" then + table.insert(hdr_tbl, k .. ": " .. table.concat(v, ",")) + else + table.insert(hdr_tbl, k .. ": " .. tostring(v)) + end + end + else + t.headers = {} -- dummy table + end + + if not t.headers.host then + -- 'Host' header must be provided for HTTP/1.1 + table.insert(hdr_tbl, "host: " .. host) + end + + if not t.headers["accept"] then + table.insert(hdr_tbl, "accept: */*") + end + + if not t.headers["user-agent"] then + table.insert(hdr_tbl, "user-agent: haproxy-lua-http/1.0") + end + + if not t.headers.connection then + table.insert(hdr_tbl, "connection: close") + end + + if t.data then + req[4] = t.data + if not t.headers or not t.headers["content-length"] then + table.insert(hdr_tbl, "content-length: " .. tostring(#t.data)) + end + end + + req[1] = method .. " " .. req_uri .. " HTTP/1.1\r\n" + req[2] = table.concat(hdr_tbl, "\r\n") + req[3] = "\r\n\r\n" + + local r, e = socket:send(table.concat(req)) + + if not r then + socket:close() + return nil, "http." .. method:lower() .. ": " .. tostring(e) + end + + local line + r = M.response.create() + + while true do + line, err = socket:receive("*l") + + if not line then + socket:close() + return nil, "http." .. method:lower() .. + ": Receive error (headers): " .. err + end + + if line == "" then break end + + if not r.status_code then + _, r.status_code, r.reason = + line:match("(HTTP/1.[01]) (%d%d%d)(.*)") + if not _ then + socket:close() + return nil, "http." .. method:lower() .. + ": Could not parse request line" + end + r.status_code = tonumber(r.status_code) + else + local sep = line:find(":") + local hdr_name = line:sub(1, sep-1):lower() + local hdr_val = line:sub(sep+1):match("^%s*(.*%S)%s*$") or "" + + if r.headers[hdr_name] == nil then + r.headers[hdr_name] = hdr_val + elseif type(r.headers[hdr_name]) == "table" then + table.insert(r.headers[hdr_name], hdr_val) + else + r.headers[hdr_name] = { + r.headers[hdr_name], + hdr_val + } + end + end + end + + if method:lower() == "head" then + r.content = nil + socket:close() + return r + end + + if r.headers["content-length"] and tonumber(r.headers["content-length"]) > 0 then + r.content, err = socket:receive("*a") + + if not r.content then + socket:close() + return nil, "http." .. method:lower() .. + ": Receive error (content): " .. err + end + end + + if r.headers["transfer-encoding"] and r.headers["transfer-encoding"] == "chunked" then + r.content, err = M.receive_chunked(socket) + if r.content == nil then + socket:close() + return nil, err + end + end + + socket:close() + return r + else + return nil, "http." .. method:lower() .. ": Connection error: " .. tostring(err) + end +end + +M.base64 = {} + +--- URL safe base64 encoder +-- +-- Padding ('=') is omited, as permited per RFC +-- https://tools.ietf.org/html/rfc4648 +-- in order to follow JSON Web Signature RFC +-- https://tools.ietf.org/html/rfc7515 +-- +-- @param s String (can be binary data) to encode +-- @param enc Function which implements base64 encoder (e.g. HAProxy base64 fetch) +-- @return Encoded string +function M.base64.encode(s, enc) + if not s then return nil end + local u = enc(s) + + if not u then + return nil + end + + local pad_len = 2 - ((#s-1) % 3) + + if pad_len > 0 then + return u:sub(1, - pad_len - 1):gsub('[+]', '-'):gsub('[/]', '_') + else + return u:gsub('[+]', '-'):gsub('[/]', '_') + end +end + +--- URLsafe base64 decoder +-- +-- @param s Base64 string to decode +-- @param dec Function which implements base64 decoder (e.g. HAProxy b64dec fetch) +-- @return Decoded string (can be binary data) +function M.base64.decode(s, dec) + if not s then return nil end + + local e = s:gsub('[-]', '+'):gsub('[_]', '/') + return dec(e .. string.rep('=', 3 - ((#s - 1) % 4))) +end + +return M diff --git a/src/libs/json.lua b/src/libs/json.lua new file mode 100644 index 0000000..711ef78 --- /dev/null +++ b/src/libs/json.lua @@ -0,0 +1,388 @@ +-- +-- json.lua +-- +-- Copyright (c) 2020 rxi +-- +-- Permission is hereby granted, free of charge, to any person obtaining a copy of +-- this software and associated documentation files (the "Software"), to deal in +-- the Software without restriction, including without limitation the rights to +-- use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +-- of the Software, and to permit persons to whom the Software is furnished to do +-- so, subject to the following conditions: +-- +-- The above copyright notice and this permission notice shall be included in all +-- copies or substantial portions of the Software. +-- +-- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +-- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +-- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +-- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +-- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +-- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +-- SOFTWARE. +-- + +local json = { _version = "0.1.2" } + +------------------------------------------------------------------------------- +-- Encode +------------------------------------------------------------------------------- + +local encode + +local escape_char_map = { + [ "\\" ] = "\\", + [ "\"" ] = "\"", + [ "\b" ] = "b", + [ "\f" ] = "f", + [ "\n" ] = "n", + [ "\r" ] = "r", + [ "\t" ] = "t", +} + +local escape_char_map_inv = { [ "/" ] = "/" } +for k, v in pairs(escape_char_map) do + escape_char_map_inv[v] = k +end + + +local function escape_char(c) + return "\\" .. (escape_char_map[c] or string.format("u%04x", c:byte())) +end + + +local function encode_nil(val) + return "null" +end + + +local function encode_table(val, stack) + local res = {} + stack = stack or {} + + -- Circular reference? + if stack[val] then error("circular reference") end + + stack[val] = true + + if rawget(val, 1) ~= nil or next(val) == nil then + -- Treat as array -- check keys are valid and it is not sparse + local n = 0 + for k in pairs(val) do + if type(k) ~= "number" then + error("invalid table: mixed or invalid key types") + end + n = n + 1 + end + if n ~= #val then + error("invalid table: sparse array") + end + -- Encode + for i, v in ipairs(val) do + table.insert(res, encode(v, stack)) + end + stack[val] = nil + return "[" .. table.concat(res, ",") .. "]" + + else + -- Treat as an object + for k, v in pairs(val) do + if type(k) ~= "string" then + error("invalid table: mixed or invalid key types") + end + table.insert(res, encode(k, stack) .. ":" .. encode(v, stack)) + end + stack[val] = nil + return "{" .. table.concat(res, ",") .. "}" + end +end + + +local function encode_string(val) + return '"' .. val:gsub('[%z\1-\31\\"]', escape_char) .. '"' +end + + +local function encode_number(val) + -- Check for NaN, -inf and inf + if val ~= val or val <= -math.huge or val >= math.huge then + error("unexpected number value '" .. tostring(val) .. "'") + end + return string.format("%.14g", val) +end + + +local type_func_map = { + [ "nil" ] = encode_nil, + [ "table" ] = encode_table, + [ "string" ] = encode_string, + [ "number" ] = encode_number, + [ "boolean" ] = tostring, +} + + +encode = function(val, stack) + local t = type(val) + local f = type_func_map[t] + if f then + return f(val, stack) + end + error("unexpected type '" .. t .. "'") +end + + +function json.encode(val) + return ( encode(val) ) +end + + +------------------------------------------------------------------------------- +-- Decode +------------------------------------------------------------------------------- + +local parse + +local function create_set(...) + local res = {} + for i = 1, select("#", ...) do + res[ select(i, ...) ] = true + end + return res +end + +local space_chars = create_set(" ", "\t", "\r", "\n") +local delim_chars = create_set(" ", "\t", "\r", "\n", "]", "}", ",") +local escape_chars = create_set("\\", "/", '"', "b", "f", "n", "r", "t", "u") +local literals = create_set("true", "false", "null") + +local literal_map = { + [ "true" ] = true, + [ "false" ] = false, + [ "null" ] = nil, +} + + +local function next_char(str, idx, set, negate) + for i = idx, #str do + if set[str:sub(i, i)] ~= negate then + return i + end + end + return #str + 1 +end + + +local function decode_error(str, idx, msg) + local line_count = 1 + local col_count = 1 + for i = 1, idx - 1 do + col_count = col_count + 1 + if str:sub(i, i) == "\n" then + line_count = line_count + 1 + col_count = 1 + end + end + error( string.format("%s at line %d col %d", msg, line_count, col_count) ) +end + + +local function codepoint_to_utf8(n) + -- http://scripts.sil.org/cms/scripts/page.php?site_id=nrsi&id=iws-appendixa + local f = math.floor + if n <= 0x7f then + return string.char(n) + elseif n <= 0x7ff then + return string.char(f(n / 64) + 192, n % 64 + 128) + elseif n <= 0xffff then + return string.char(f(n / 4096) + 224, f(n % 4096 / 64) + 128, n % 64 + 128) + elseif n <= 0x10ffff then + return string.char(f(n / 262144) + 240, f(n % 262144 / 4096) + 128, + f(n % 4096 / 64) + 128, n % 64 + 128) + end + error( string.format("invalid unicode codepoint '%x'", n) ) +end + + +local function parse_unicode_escape(s) + local n1 = tonumber( s:sub(1, 4), 16 ) + local n2 = tonumber( s:sub(7, 10), 16 ) + -- Surrogate pair? + if n2 then + return codepoint_to_utf8((n1 - 0xd800) * 0x400 + (n2 - 0xdc00) + 0x10000) + else + return codepoint_to_utf8(n1) + end +end + + +local function parse_string(str, i) + local res = "" + local j = i + 1 + local k = j + + while j <= #str do + local x = str:byte(j) + + if x < 32 then + decode_error(str, j, "control character in string") + + elseif x == 92 then -- `\`: Escape + res = res .. str:sub(k, j - 1) + j = j + 1 + local c = str:sub(j, j) + if c == "u" then + local hex = str:match("^[dD][89aAbB]%x%x\\u%x%x%x%x", j + 1) + or str:match("^%x%x%x%x", j + 1) + or decode_error(str, j - 1, "invalid unicode escape in string") + res = res .. parse_unicode_escape(hex) + j = j + #hex + else + if not escape_chars[c] then + decode_error(str, j - 1, "invalid escape char '" .. c .. "' in string") + end + res = res .. escape_char_map_inv[c] + end + k = j + 1 + + elseif x == 34 then -- `"`: End of string + res = res .. str:sub(k, j - 1) + return res, j + 1 + end + + j = j + 1 + end + + decode_error(str, i, "expected closing quote for string") +end + + +local function parse_number(str, i) + local x = next_char(str, i, delim_chars) + local s = str:sub(i, x - 1) + local n = tonumber(s) + if not n then + decode_error(str, i, "invalid number '" .. s .. "'") + end + return n, x +end + + +local function parse_literal(str, i) + local x = next_char(str, i, delim_chars) + local word = str:sub(i, x - 1) + if not literals[word] then + decode_error(str, i, "invalid literal '" .. word .. "'") + end + return literal_map[word], x +end + + +local function parse_array(str, i) + local res = {} + local n = 1 + i = i + 1 + while 1 do + local x + i = next_char(str, i, space_chars, true) + -- Empty / end of array? + if str:sub(i, i) == "]" then + i = i + 1 + break + end + -- Read token + x, i = parse(str, i) + res[n] = x + n = n + 1 + -- Next token + i = next_char(str, i, space_chars, true) + local chr = str:sub(i, i) + i = i + 1 + if chr == "]" then break end + if chr ~= "," then decode_error(str, i, "expected ']' or ','") end + end + return res, i +end + + +local function parse_object(str, i) + local res = {} + i = i + 1 + while 1 do + local key, val + i = next_char(str, i, space_chars, true) + -- Empty / end of object? + if str:sub(i, i) == "}" then + i = i + 1 + break + end + -- Read key + if str:sub(i, i) ~= '"' then + decode_error(str, i, "expected string for key") + end + key, i = parse(str, i) + -- Read ':' delimiter + i = next_char(str, i, space_chars, true) + if str:sub(i, i) ~= ":" then + decode_error(str, i, "expected ':' after key") + end + i = next_char(str, i + 1, space_chars, true) + -- Read value + val, i = parse(str, i) + -- Set + res[key] = val + -- Next token + i = next_char(str, i, space_chars, true) + local chr = str:sub(i, i) + i = i + 1 + if chr == "}" then break end + if chr ~= "," then decode_error(str, i, "expected '}' or ','") end + end + return res, i +end + + +local char_func_map = { + [ '"' ] = parse_string, + [ "0" ] = parse_number, + [ "1" ] = parse_number, + [ "2" ] = parse_number, + [ "3" ] = parse_number, + [ "4" ] = parse_number, + [ "5" ] = parse_number, + [ "6" ] = parse_number, + [ "7" ] = parse_number, + [ "8" ] = parse_number, + [ "9" ] = parse_number, + [ "-" ] = parse_number, + [ "t" ] = parse_literal, + [ "f" ] = parse_literal, + [ "n" ] = parse_literal, + [ "[" ] = parse_array, + [ "{" ] = parse_object, +} + + +parse = function(str, idx) + local chr = str:sub(idx, idx) + local f = char_func_map[chr] + if f then + return f(str, idx) + end + decode_error(str, idx, "unexpected character '" .. chr .. "'") +end + + +function json.decode(str) + if type(str) ~= "string" then + error("expected argument of type string, got " .. type(str)) + end + local res, idx = parse(str, next_char(str, 1, space_chars, true)) + idx = next_char(str, idx, space_chars, true) + if idx <= #str then + decode_error(str, idx, "trailing garbage") + end + return res +end + + +return json diff --git a/src/libs/url.lua b/src/libs/url.lua new file mode 100644 index 0000000..0aa4ffc --- /dev/null +++ b/src/libs/url.lua @@ -0,0 +1,451 @@ +-- neturl.lua - a robust url parser and builder +-- +-- Bertrand Mansion, 2011-2013; License MIT +-- @module neturl +-- @alias M + +local M = {} +M.version = "0.9.0" + +--- url options +-- separator is set to `&` by default but could be anything like `&amp;` or `;` +-- @todo Add an option to limit the size of the argument table +M.options = { + separator = '&' +} + +--- list of known and common scheme ports +-- as documented in IANA URI scheme list +M.services = { + acap = 674, + cap = 1026, + dict = 2628, + ftp = 21, + gopher = 70, + http = 80, + https = 443, + iax = 4569, + icap = 1344, + imap = 143, + ipp = 631, + ldap = 389, + mtqp = 1038, + mupdate = 3905, + news = 2009, + nfs = 2049, + nntp = 119, + rtsp = 554, + sip = 5060, + snmp = 161, + telnet = 23, + tftp = 69, + vemmi = 575, + afs = 1483, + jms = 5673, + rsync = 873, + prospero = 191, + videotex = 516 +} + +local legal = { + ["-"] = true, ["_"] = true, ["."] = true, ["!"] = true, + ["~"] = true, ["*"] = true, ["'"] = true, ["("] = true, + [")"] = true, [":"] = true, ["@"] = true, ["&"] = true, + ["="] = true, ["+"] = true, ["$"] = true, [","] = true, + [";"] = true -- can be used for parameters in path +} + +local function decode(str, path) + local str = str + if not path then + str = str:gsub('+', ' ') + end + return (str:gsub("%%(%x%x)", function(c) + return string.char(tonumber(c, 16)) + end)) +end + +local function encode(str) + return (str:gsub("([^A-Za-z0-9%_%.%-%~])", function(v) + return string.upper(string.format("%%%02x", string.byte(v))) + end)) +end + +-- for query values, prefer + instead of %20 for spaces +local function encodeValue(str) + local str = encode(str) + return str:gsub('%%20', '+') +end + +local function encodeSegment(s) + local legalEncode = function(c) + if legal[c] then + return c + end + return encode(c) + end + return s:gsub('([^a-zA-Z0-9])', legalEncode) +end + +local function concat(s, u) + return s .. u:build() +end + +--- builds the url +-- @return a string representing the built url +function M:build() + local url = '' + if self.path then + local path = self.path + path:gsub("([^/]+)", function (s) return encodeSegment(s) end) + url = url .. tostring(path) + end + if self.query then + local qstring = tostring(self.query) + if qstring ~= "" then + url = url .. '?' .. qstring + end + end + if self.host then + local authority = self.host + if self.port and self.scheme and M.services[self.scheme] ~= self.port then + authority = authority .. ':' .. self.port + end + local userinfo + if self.user and self.user ~= "" then + userinfo = self.user + if self.password then + userinfo = userinfo .. ':' .. self.password + end + end + if userinfo and userinfo ~= "" then + authority = userinfo .. '@' .. authority + end + if authority then + if url ~= "" then + url = '//' .. authority .. '/' .. url:gsub('^/+', '') + else + url = '//' .. authority + end + end + end + if self.scheme then + url = self.scheme .. ':' .. url + end + if self.fragment then + url = url .. '#' .. self.fragment + end + return url +end + +--- builds the querystring +-- @param tab The key/value parameters +-- @param sep The separator to use (optional) +-- @param key The parent key if the value is multi-dimensional (optional) +-- @return a string representing the built querystring +function M.buildQuery(tab, sep, key) + local query = {} + if not sep then + sep = M.options.separator or '&' + end + local keys = {} + for k in pairs(tab) do + keys[#keys+1] = k + end + table.sort(keys) + for _,name in ipairs(keys) do + local value = tab[name] + name = encode(tostring(name)) + if key then + name = string.format('%s[%s]', tostring(key), tostring(name)) + end + if type(value) == 'table' then + query[#query+1] = M.buildQuery(value, sep, name) + else + local value = encodeValue(tostring(value)) + if value ~= "" then + query[#query+1] = string.format('%s=%s', name, value) + else + query[#query+1] = name + end + end + end + return table.concat(query, sep) +end + +--- Parses the querystring to a table +-- This function can parse multidimensional pairs and is mostly compatible +-- with PHP usage of brackets in key names like ?param[key]=value +-- @param str The querystring to parse +-- @param sep The separator between key/value pairs, defaults to `&` +-- @todo limit the max number of parameters with M.options.max_parameters +-- @return a table representing the query key/value pairs +function M.parseQuery(str, sep) + if not sep then + sep = M.options.separator or '&' + end + + local values = {} + for key,val in str:gmatch(string.format('([^%q=]+)(=*[^%q=]*)', sep, sep)) do + local key = decode(key) + local keys = {} + key = key:gsub('%[([^%]]*)%]', function(v) + -- extract keys between balanced brackets + if string.find(v, "^-?%d+$") then + v = tonumber(v) + else + v = decode(v) + end + table.insert(keys, v) + return "=" + end) + key = key:gsub('=+.*$', "") + key = key:gsub('%s', "_") -- remove spaces in parameter name + val = val:gsub('^=+', "") + + if not values[key] then + values[key] = {} + end + if #keys > 0 and type(values[key]) ~= 'table' then + values[key] = {} + elseif #keys == 0 and type(values[key]) == 'table' then + values[key] = decode(val) + end + + local t = values[key] + for i,k in ipairs(keys) do + if type(t) ~= 'table' then + t = {} + end + if k == "" then + k = #t+1 + end + if not t[k] then + t[k] = {} + end + if i == #keys then + t[k] = decode(val) + end + t = t[k] + end + end + setmetatable(values, { __tostring = M.buildQuery }) + return values +end + +--- set the url query +-- @param query Can be a string to parse or a table of key/value pairs +-- @return a table representing the query key/value pairs +function M:setQuery(query) + local query = query + if type(query) == 'table' then + query = M.buildQuery(query) + end + self.query = M.parseQuery(query) + return query +end + +--- set the authority part of the url +-- The authority is parsed to find the user, password, port and host if available. +-- @param authority The string representing the authority +-- @return a string with what remains after the authority was parsed +function M:setAuthority(authority) + self.authority = authority + self.port = nil + self.host = nil + self.userinfo = nil + self.user = nil + self.password = nil + + authority = authority:gsub('^([^@]*)@', function(v) + self.userinfo = v + return '' + end) + authority = authority:gsub("^%[[^%]]+%]", function(v) + -- ipv6 + self.host = v + return '' + end) + authority = authority:gsub(':([^:]*)$', function(v) + self.port = tonumber(v) + return '' + end) + if authority ~= '' and not self.host then + self.host = authority:lower() + end + if self.userinfo then + local userinfo = self.userinfo + userinfo = userinfo:gsub(':([^:]*)$', function(v) + self.password = v + return '' + end) + self.user = userinfo + end + return authority +end + +--- Parse the url into the designated parts. +-- Depending on the url, the following parts can be available: +-- scheme, userinfo, user, password, authority, host, port, path, +-- query, fragment +-- @param url Url string +-- @return a table with the different parts and a few other functions +function M.parse(url) + local comp = {} + M.setAuthority(comp, "") + M.setQuery(comp, "") + + local url = tostring(url or '') + url = url:gsub('#(.*)$', function(v) + comp.fragment = v + return '' + end) + url =url:gsub('^([%w][%w%+%-%.]*)%:', function(v) + comp.scheme = v:lower() + return '' + end) + url = url:gsub('%?(.*)', function(v) + M.setQuery(comp, v) + return '' + end) + url = url:gsub('^//([^/]*)', function(v) + M.setAuthority(comp, v) + return '' + end) + comp.path = decode(url, true) + + setmetatable(comp, { + __index = M, + __concat = concat, + __tostring = M.build} + ) + return comp +end + +--- removes dots and slashes in urls when possible +-- This function will also remove multiple slashes +-- @param path The string representing the path to clean +-- @return a string of the path without unnecessary dots and segments +function M.removeDotSegments(path) + local fields = {} + if string.len(path) == 0 then + return "" + end + local startslash = false + local endslash = false + if string.sub(path, 1, 1) == "/" then + startslash = true + end + if (string.len(path) > 1 or startslash == false) and string.sub(path, -1) == "/" then + endslash = true + end + + path:gsub('[^/]+', function(c) table.insert(fields, c) end) + + local new = {} + local j = 0 + + for i,c in ipairs(fields) do + if c == '..' then + if j > 0 then + j = j - 1 + end + elseif c ~= "." then + j = j + 1 + new[j] = c + end + end + local ret = "" + if #new > 0 and j > 0 then + ret = table.concat(new, '/', 1, j) + else + ret = "" + end + if startslash then + ret = '/'..ret + end + if endslash then + ret = ret..'/' + end + return ret +end + +local function absolutePath(base_path, relative_path) + if string.sub(relative_path, 1, 1) == "/" then + return '/' .. string.gsub(relative_path, '^[%./]+', '') + end + local path = base_path + if relative_path ~= "" then + path = '/'..path:gsub("[^/]*$", "") + end + path = path .. relative_path + path = path:gsub("([^/]*%./)", function (s) + if s ~= "./" then return s else return "" end + end) + path = string.gsub(path, "/%.$", "/") + local reduced + while reduced ~= path do + reduced = path + path = string.gsub(reduced, "([^/]*/%.%./)", function (s) + if s ~= "../../" then return "" else return s end + end) + end + path = string.gsub(path, "([^/]*/%.%.?)$", function (s) + if s ~= "../.." then return "" else return s end + end) + local reduced + while reduced ~= path do + reduced = path + path = string.gsub(reduced, '^/?%.%./', '') + end + return '/' .. path +end + +--- builds a new url by using the one given as parameter and resolving paths +-- @param other A string or a table representing a url +-- @return a new url table +function M:resolve(other) + if type(self) == "string" then + self = M.parse(self) + end + if type(other) == "string" then + other = M.parse(other) + end + if other.scheme then + return other + else + other.scheme = self.scheme + if not other.authority or other.authority == "" then + other:setAuthority(self.authority) + if not other.path or other.path == "" then + other.path = self.path + local query = other.query + if not query or not next(query) then + other.query = self.query + end + else + other.path = absolutePath(self.path, other.path) + end + end + return other + end +end + +--- normalize a url path following some common normalization rules +-- described on The URL normalization page of Wikipedia +-- @return the normalized path +function M:normalize() + if type(self) == 'string' then + self = M.parse(self) + end + if self.path then + local path = self.path + path = absolutePath(path, "") + -- normalize multiple slashes + path = string.gsub(path, "//+", "/") + self.path = path + end + return self +end + +return M diff --git a/src/libs/utils.lua b/src/libs/utils.lua index e70268a..6f37054 100644 --- a/src/libs/utils.lua +++ b/src/libs/utils.lua @@ -1,17 +1,18 @@ local _M = {} -local md5 = require("md5") function _M.get_hostname() - local f = io.popen ("/bin/hostname") - local hostname = f:read("*a") or "" - f:close() + local handler = io.popen ("/bin/hostname") + local hostname = handler:read("*a") or "" + handler:close() hostname =string.gsub(hostname, "\n$", "") return hostname end -function _M.get_floating_hash() - -- This ensures that a cookie is rotated every day - return md5.sumhexa(_M.get_hostname() .. os.date("%d")) +function _M.resolve_fqdn(fqdn) + local handler = io.popen(string.format("dig +short %s | head -1", fqdn)) + local result = handler:read("*a") + handler:close() + return result:gsub("\n", "") end return _M diff --git a/src/scripts/hcaptcha.lua b/src/scripts/hcaptcha.lua index a5bcc5f..92a48c2 100644 --- a/src/scripts/hcaptcha.lua +++ b/src/scripts/hcaptcha.lua @@ -1,14 +1,17 @@ _M = {} -local url = require("net.url") -local https = require("ssl.https") -local json = require("rapidjson") +local url = require("url") +local http = require("http") local utils = require("utils") local cookie = require("cookie") +local json = require("json") -local floating_hash = utils.get_floating_hash() -local hcaptcha_secret = os.getenv("HCAPTCHA_SECRET") -local hcaptcha_sitekey = os.getenv("HCAPTCHA_SITEKEY") +local captcha_secret = os.getenv("HCAPTCHA_SECRET") +local captcha_sitekey = os.getenv("HCAPTCHA_SITEKEY") + +-- HaProxy Lua is not capable of FQDN resolution :( +local captcha_provider_domain = "hcaptcha.com" +local captcha_provider_ip = utils.resolve_fqdn(captcha_provider_domain) function _M.view(applet) local response_body @@ -42,7 +45,7 @@ function _M.view(applet) ]] - response_body = string.format(response_body, hcaptcha_sitekey) + response_body = string.format(response_body, captcha_sitekey) response_status_code = 200 elseif applet.method == "POST" then local parsed_body = url.parseQuery(applet.receive(applet)) @@ -50,21 +53,29 @@ function _M.view(applet) if parsed_body["h-captcha-response"] then local url = string.format( - "https://hcaptcha.com/siteverify?secret=%s&response=%s", - hcaptcha_secret, + "https://%s/siteverify?secret=%s&response=%s", + captcha_provider_ip, + captcha_secret, parsed_body["h-captcha-response"] ) - local body, _, _, _ = https.request(url) - local api_response = json.decode(body) + local res, err = http.get{url=url, headers={host=captcha_provider_domain} } + local status, api_response = pcall(res.json, res) + + if not status then + local original_error = api_response + api_response = {} + core.Warning("Received incorrect response from Captcha Provider: " .. original_error) + end if api_response.success == true then + local floating_hash = applet.sc:xxh32(utils.get_hostname()) core.Debug("HCAPTCHA SUCCESSFULLY PASSED") applet:add_header( "set-cookie", string.format("z_ddos_protection=%s; Max-Age=14400; Path=/", floating_hash) ) else - core.Debug("HCAPTCHA FAILED: " .. body) + core.Debug("HCAPTCHA FAILED: " .. json.encode(api_response)) end end @@ -84,11 +95,12 @@ function _M.check_captcha_status(txn) core.Debug("CAPTCHA STATUS CHECK START") txn:set_var("txn.requested_url", "/mopsik?kek=pek") local parsed_request_cookies = cookie.get_cookie_table(txn.sf:hdr("Cookie")) + local expected_cookie = txn.sc:xxh32(utils.get_hostname()) core.Debug("RECEIVED SECRET COOKIE: " .. parsed_request_cookies["z_ddos_protection"]) - core.Debug("OUR SECRET COOKIE: " .. floating_hash) + core.Debug("OUR SECRET COOKIE: " .. expected_cookie) - if parsed_request_cookies["z_ddos_protection"] == floating_hash then + if parsed_request_cookies["z_ddos_protection"] == expected_cookie then core.Debug("CAPTCHA STATUS CHECK SUCCESS") return txn:set_var("txn.captcha_passed", true) end