Skip to content

Writing Lua extension library using Rust

Bruce edited this page Dec 24, 2024 · 2 revisions

介绍

Rust有完善的包管理机制,借助Rust可以极大地丰富Lua扩展库,例如依赖tokio、https client、sqlx等网络相关的库。使用Rust编写Lua扩展库也比较容易,因为Rust本身就支持编写供C/C++调用的动态库。

编写Lua扩展库需要依赖lib-lua-sys,它参考了mlua-sys(Low level (FFI) bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Roblox Luau)。Rust的mlua库本身支持编写Lua扩展库,但它比较复杂,使用起来也没有Lua C API灵活,所以这里只使用它的FFI binding部分。同时lib-lua-sys也做了一些改动:mlua默认静态链接Lua库,而由于要给moon编写扩展库,lib-lua-sys改为动态链接Lua库。

编写基础库

新建Rust项目, 手动添加lib-lua-sys。具体Rust包管理机制请参考Rust相关文档,这里不再详细描述。

[dependencies]
lib-core = { path = "../../libs/lib-core"}
lib-lua = {package = "lib-lua-sys", path = "../../libs/lib-lua-sys",features = ["lua54"]}

然后就可以用类似Lua C API的方式编写Lua扩展库了,下面以lua_excel为例:

use calamine::{open_workbook, Data, Reader, Xlsx};
use csv::ReaderBuilder;
use lib_lua::{self, cstr, ffi, ffi::luaL_Reg, laux, lreg, lreg_null};
use std::{os::raw::c_int, path::Path};

/// Reads a CSV file (no header row) and pushes its contents onto the Lua stack.
///
/// On success pushes a one-element list, mirroring the per-sheet layout used for `.xlsx`:
/// `{ { sheet_name = <file stem>, data = { {field1, field2, ...}, ... } } }`
/// and returns 1. At most `max_row` records are read.
///
/// On failure pushes `false, <error message>` and returns 2. Partially built tables may remain
/// below those two values; Lua only returns the top `n` values of a C function, so they are
/// simply discarded.
fn read_csv(state: *mut ffi::lua_State, path: &Path, max_row: usize) -> c_int {
    let mut reader = match ReaderBuilder::new().has_headers(false).from_path(path) {
        Ok(reader) => reader,
        Err(err) => {
            unsafe {
                ffi::lua_pushboolean(state, 0);
            }
            laux::lua_push(
                state,
                format!("open file '{}' error: {}", path.to_string_lossy(), err).as_str(),
            );
            return 2;
        }
    };

    unsafe {
        // Result list: holds exactly one "sheet".
        ffi::lua_createtable(state, 1, 0);
        // The sheet table: { sheet_name = ..., data = ... }.
        ffi::lua_createtable(state, 0, 2);
        laux::lua_push(
            state,
            path.file_stem()
                .unwrap_or_default()
                .to_str()
                .unwrap_or_default(),
        );
        ffi::lua_setfield(state, -2, cstr!("sheet_name"));
        // The `data` array; the row count is unknown up front, so reserve a bounded guess.
        ffi::lua_createtable(state, c_int::try_from(max_row.min(1024)).unwrap_or(0), 0);
    }

    let mut row_count: usize = 0;
    for result in reader.records().take(max_row) {
        let record = match result {
            Ok(record) => record,
            Err(err) => {
                unsafe {
                    ffi::lua_pushboolean(state, 0);
                }
                laux::lua_push(
                    state,
                    format!("read csv '{}' error: {}", path.to_string_lossy(), err).as_str(),
                );
                return 2;
            }
        };

        unsafe {
            // Fields are stored at integer keys 1..=len, so reserve the array part.
            ffi::lua_createtable(state, c_int::try_from(record.len()).unwrap_or(0), 0);
        }
        for (i, field) in record.iter().enumerate() {
            laux::lua_push(state, field);
            unsafe {
                ffi::lua_rawseti(state, -2, (i + 1) as ffi::lua_Integer);
            }
        }
        row_count += 1;
        unsafe {
            ffi::lua_rawseti(state, -2, row_count as ffi::lua_Integer);
        }
    }

    unsafe {
        ffi::lua_setfield(state, -2, cstr!("data"));
        ffi::lua_rawseti(state, -2, 1);
    }
    1
}

fn read_xlxs(state: *mut ffi::lua_State, path: &Path, max_row: usize) -> c_int {
    let res: Result<Xlsx<_>, _> = open_workbook(path);
    match res {
        Ok(mut workbook) => {
            unsafe {
                ffi::lua_createtable(state, 0, 0);
            }
            let mut sheet_counter = 0;
            workbook.sheet_names().iter().for_each(|sheet| {
                if let Ok(range) = workbook.worksheet_range(sheet) {
                    unsafe {
                        ffi::lua_createtable(state, 0, 2);
                        laux::lua_push(state, sheet.as_str());

                        ffi::lua_setfield(state, -2, cstr!("sheet_name"));

                        ffi::lua_createtable(state, range.rows().len() as i32, 0);
                        for (i, row) in range.rows().enumerate() {
                            if i >= max_row {
                                break;
                            }
                            //rows
                            ffi::lua_createtable(state, row.len() as i32, 0);

                            for (j, cell) in row.iter().enumerate() {
                                //columns

                                match cell {
                                    Data::Int(v) => {
                                        ffi::lua_pushinteger(state, *v as ffi::lua_Integer)
                                    }
                                    Data::Float(v) => ffi::lua_pushnumber(state, *v),
                                    Data::String(v) => laux::lua_push(state, v.as_str()),
                                    Data::Bool(v) => ffi::lua_pushboolean(state, *v as i32),
                                    Data::Error(v) => laux::lua_push(state, v.to_string()),
                                    Data::Empty => ffi::lua_pushnil(state),
                                    Data::DateTime(v) => laux::lua_push(state, v.to_string()),
                                    _ => ffi::lua_pushnil(state),
                                }
                                ffi::lua_rawseti(state, -2, (j + 1) as i64);
                            }
                            ffi::lua_rawseti(state, -2, (i + 1) as i64);
                        }
                        ffi::lua_setfield(state, -2, cstr!("data"));
                    }
                    sheet_counter += 1;
                    unsafe {
                        ffi::lua_rawseti(state, -2, sheet_counter as i64);
                    }
                }
            });
            1
        }
        Err(err) => unsafe {
            ffi::lua_pushboolean(state, 0);
            laux::lua_push(state, format!("{}", err).as_str());
            2
        },
    }
}

/// Lua: `read(filename [, max_row])` — loads a `.csv` or `.xlsx` file into Lua tables.
///
/// The extension is matched case-insensitively, so `DATA.CSV` and `Book.XLSX` are accepted.
/// `max_row` limits the number of rows read (per sheet) and defaults to unlimited.
/// Returns the sheet list built by `read_csv` / `read_xlxs`, or `false, <message>` when the
/// extension is missing or unsupported, or the file cannot be read.
extern "C-unwind" fn lua_excel_read(state: *mut ffi::lua_State) -> c_int {
    let filename: &str = laux::lua_get(state, 1);
    let max_row: usize = laux::lua_opt(state, 2).unwrap_or(usize::MAX);
    let path = Path::new(filename);

    let message = match path.extension() {
        Some(ext) => {
            let ext = ext.to_string_lossy();
            if ext.eq_ignore_ascii_case("csv") {
                return read_csv(state, path, max_row);
            }
            if ext.eq_ignore_ascii_case("xlsx") {
                return read_xlxs(state, path, max_row);
            }
            format!("unsupport file type: {}", ext)
        }
        None => format!("unsupport file type: {}", path.to_string_lossy()),
    };

    unsafe {
        ffi::lua_pushboolean(state, 0);
    }
    laux::lua_push(state, message);
    2
}

/// Lua module entry point for `require "rust.excel"`: returns a table `{ read = ... }`.
///
/// # Safety
///
/// This function is unsafe because it dereferences a raw pointer `state`.
/// The caller must ensure that `state` is a valid pointer to a `lua_State`
/// and that it remains valid for the duration of the function call.
#[no_mangle]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub unsafe extern "C-unwind" fn luaopen_rust_excel(state: *mut ffi::lua_State) -> c_int {
    let l = [lreg!("read", lua_excel_read), lreg_null!()];

    // SAFETY: `state` is valid per this function's contract, and `l` is terminated by the
    // `lreg_null!()` sentinel that `luaL_setfuncs` requires.
    unsafe {
        // Exclude the sentinel from the field-count hint, as `luaL_newlib` does.
        ffi::lua_createtable(state, 0, (l.len() - 1) as c_int);
        ffi::luaL_setfuncs(state, l.as_ptr(), 0);
    }

    1
}

编写异步库

对于带异步调用的库,一般是和框架的事件循环相关,由于moon是基于Actor模型的,一切皆消息, 只需要把发送消息的函数, 导出给Rust就可以接入到moon的事件循环系统中。

// Exported with C linkage so Rust extension libraries can post messages into moon.
extern "C" {
void MOON_EXPORT
send_message(uint8_t type, uint32_t receiver, int64_t session, const char* data, size_t len) {
    // `wk_server.lock()` yields an empty pointer when the server is gone;
    // in that case the message is silently dropped.
    if (auto svr = wk_server.lock()) {
        moon::message msg(len);
        msg.set_type(type);
        msg.set_receiver(receiver);
        msg.set_sessionid(session);
        // The payload is written into the message before this function returns.
        msg.write_data(std::string_view(data, len));
        svr->send_message(std::move(msg));
    }
}
}

rust中使用导出的函数

// FFI declaration of moon's exported C++ function:
//   void send_message(uint8_t type, uint32_t receiver, int64_t session, const char* data, size_t len)
// The parameter types and order must match that signature exactly.
// NOTE(review): `*const i8` stands in for `const char*`; `std::ffi::c_char` is the portable
// spelling (it is `u8` on some targets), but the callers cast to `*const i8` explicitly.
unsafe extern "C-unwind" {
    unsafe fn send_message(type_: u8, receiver: u32, session: i64, data: *const i8, len: usize);
}

/// Transfers ownership of `res` to actor `owner` as a type-erased pointer message.
///
/// `res` is boxed and leaked. The message payload is the box address encoded as native-endian
/// `isize` bytes. The receiver must rebuild it with `Box::from_raw` using exactly the same `T`
/// (as the HTTP `decode` function does for `reqwest::Response`). If moon drops the message
/// (e.g. the server is gone), the value is leaked.
///
/// When `session` is 0 nothing is allocated and nothing is sent.
pub fn moon_send<T>(
    protocol_type: u8,
    owner: u32,
    session: i64,
    res: T,
) {
    if session != 0 {
        let address = Box::into_raw(Box::new(res)) as isize;
        let payload = address.to_ne_bytes();

        // SAFETY: `payload` is a live local buffer of `payload.len()` bytes for the duration of
        // the call; `send_message` copies it into the outgoing message.
        unsafe {
            send_message(
                protocol_type,
                owner,
                session,
                payload.as_ptr().cast::<i8>(),
                payload.len(),
            );
        }
    }
}

/// Sends the bytes of `data` to actor `owner` as a `protocol_type` message for `session`.
///
/// Unlike `moon_send`, a `session` of 0 is still sent. The bytes are handed to `send_message`,
/// which writes them into the outgoing message before returning, so `data` can be dropped
/// afterwards.
pub fn moon_send_string(
    protocol_type: u8,
    owner: u32,
    session: i64,
    data: String,
) {
    let bytes = data.as_bytes();
    // SAFETY: `bytes` borrows `data`, which outlives the call.
    unsafe {
        send_message(protocol_type, owner, session, bytes.as_ptr().cast(), bytes.len());
    }
}

/// moon protocol type used to report a failed request back to the waiting session
/// (used by `lua_http_request`).
pub const PTYPE_ERROR: u8 = 4;
/// moon protocol type for log lines (used by `moon_log`).
pub const PTYPE_LOG: u8 = 13;

// Log levels accepted by `moon_log`; the level's decimal value prefixes the log text.
pub const LOG_LEVEL_ERROR: u8 = 1;
pub const LOG_LEVEL_WARN: u8 = 2;
pub const LOG_LEVEL_INFO: u8 = 3;
pub const LOG_LEVEL_DEBUG: u8 = 4;

pub fn moon_log(owner: u32, log_level: u8, data: String) {
    let message = format!("{}{}", log_level, data);
    unsafe {
        send_message(
            PTYPE_LOG,
            owner,
            0,
            message.as_ptr() as *const i8,
            message.len(),
        );
    }
}

注意: 带异步运行时的Rust扩展库不能随Lua虚拟机关闭而卸载,因为异步运行时的后台线程可能仍在执行该动态库中的代码,库被卸载后这些线程会访问已释放的代码而崩溃。因此需要修改Lua源码,取消dlclose(lib)/FreeLibrary((HMODULE)lib)。

这里拿lua_http库举例

use lib_core::context::CONTEXT;
use lib_lua::{
    self, cstr,
    ffi::{self, luaL_Reg},
    laux, lreg, lreg_null, lua_rawsetfield,
};
use reqwest::{header::HeaderMap, Method, Response};
use std::{error::Error, ffi::c_int, str::FromStr};
use url::form_urlencoded::{self};

use crate::{moon_send, moon_send_string, PTYPE_ERROR};

/// One outgoing HTTP request, built from the Lua options table by `lua_http_request`
/// and moved into the spawned async task.
struct HttpRequest {
    /// Actor id that receives the response message (`opts.owner`).
    owner: u32,
    /// Session id the Lua side waits on; when 0, `moon_send` sends no response.
    session: i64,
    /// HTTP method name, e.g. "GET" (defaults to "GET" when unset).
    method: String,
    url: String,
    body: String,
    headers: HeaderMap,
    /// Request timeout; the Lua wrapper documents it in seconds (default 5).
    timeout: u64,
    /// Proxy setting passed to `get_http_client`; empty string when unset.
    proxy: String,
}

/// Returns the protocol label for `version` ("HTTP/0.9" ... "HTTP/3.0"), or "Unknown".
///
/// The result is always a string literal, so it is `'static` and does not borrow from
/// `version` (the elided lifetime previously tied it to the argument for no reason).
fn version_to_string(version: &reqwest::Version) -> &'static str {
    match *version {
        reqwest::Version::HTTP_09 => "HTTP/0.9",
        reqwest::Version::HTTP_10 => "HTTP/1.0",
        reqwest::Version::HTTP_11 => "HTTP/1.1",
        reqwest::Version::HTTP_2 => "HTTP/2.0",
        reqwest::Version::HTTP_3 => "HTTP/3.0",
        _ => "Unknown",
    }
}

async fn http_request(
    req: HttpRequest,
    protocol_type: u8,
) -> Result<(), Box<dyn Error>> {
    let http_client = &CONTEXT.get_http_client(req.timeout, &req.proxy);

    let response = http_client
        .request(Method::from_str(req.method.as_str())?, req.url)
        .headers(req.headers)
        .body(req.body)
        .send()
        .await?;

    moon_send(protocol_type, req.owner, req.session, response);

    Ok(())
}

/// Reads the optional `headers` field of the Lua table at `index` into a [`HeaderMap`].
///
/// `index` must be an absolute (positive) stack index: the field name is pushed before the
/// lookup, which would shift a relative index. A missing or non-table `headers` field yields an
/// empty map. Keys/values that `lua_opt` cannot read as strings fall back to "".
/// NOTE(review): if `lua_opt::<&str>` converts numbers in place, a numeric key would confuse
/// `lua_next` — confirm it only accepts real strings.
///
/// The stack is balanced on success.
///
/// # Errors
/// Returns the parse error text for the first invalid header name or value. In that case the
/// traversal state (headers table, key, value) is left on the stack; the caller pushes its
/// error results and returns immediately.
fn extract_headers(state: *mut ffi::lua_State, index: i32) -> Result<HeaderMap, String> {
    let mut headers = HeaderMap::new();

    laux::push_c_string(state, cstr!("headers"));
    // [+1] `t.headers` (of any type) is now on top of the stack.
    if laux::lua_rawget(state, index) == ffi::LUA_TTABLE {
        laux::lua_pushnil(state);
        while laux::lua_next(state, -2) {
            let key: &str = laux::lua_opt(state, -2).unwrap_or_default();
            let value: &str = laux::lua_opt(state, -1).unwrap_or_default();
            let name = key
                .parse::<reqwest::header::HeaderName>()
                .map_err(|err| err.to_string())?;
            let value = value
                .parse::<reqwest::header::HeaderValue>()
                .map_err(|err| err.to_string())?;
            headers.insert(name, value);
            laux::lua_pop(state, 1); // pop the value, keep the key for the next `lua_next`
        }
    }
    // Pop `t.headers` whether or not it was a table so the stack stays balanced.
    laux::lua_pop(state, 1);

    Ok(headers)
}

/// Lua: `request(opts, protocol_type)` — starts an asynchronous HTTP request.
///
/// `opts` (argument 1) supplies `owner`, `session`, `method` (default "GET"), `url`, `body`,
/// `headers`, `timeout` (default 5) and `proxy`. The request runs on the shared tokio runtime.
/// The response is delivered to `owner` as a `protocol_type` message. A failure is reported as
/// a `PTYPE_ERROR` message carrying the error text.
///
/// Returns the session id immediately. Returns `false, <message>` when a header is invalid or
/// no tokio runtime is available.
extern "C-unwind" fn lua_http_request(state: *mut ffi::lua_State) -> c_int {
    laux::lua_checktype(state, 1, ffi::LUA_TTABLE);
    let protocol_type = laux::lua_get::<u8>(state, 2);

    let headers = match extract_headers(state, 1) {
        Ok(map) => map,
        Err(message) => {
            laux::lua_push(state, false);
            laux::lua_push(state, message);
            return 2;
        }
    };

    let session: i64 = laux::opt_field(state, 1, "session").unwrap_or(0);

    // Field initializers run in the order written, so the Lua fields are read in this order.
    let req = HttpRequest {
        owner: laux::opt_field(state, 1, "owner").unwrap_or_default(),
        session,
        method: laux::opt_field(state, 1, "method").unwrap_or_else(|| "GET".to_string()),
        url: laux::opt_field(state, 1, "url").unwrap_or_default(),
        body: laux::opt_field(state, 1, "body").unwrap_or_default(),
        headers,
        timeout: laux::opt_field(state, 1, "timeout").unwrap_or(5),
        proxy: laux::opt_field(state, 1, "proxy").unwrap_or_default(),
    };

    match CONTEXT.get_tokio_runtime().as_ref() {
        Some(runtime) => {
            runtime.spawn(async move {
                let (owner, session) = (req.owner, req.session);
                if let Err(err) = http_request(req, protocol_type).await {
                    moon_send_string(PTYPE_ERROR, owner, session, err.to_string());
                }
            });
            laux::lua_push(state, session);
            1
        }
        None => {
            laux::lua_push(state, false);
            laux::lua_push(state, "No tokio runtime");
            2
        }
    }
}

/// Lua: `form_urlencode(t)` — encodes the key/value pairs of table `t` as
/// `application/x-www-form-urlencoded` (`k1=v1&k2=v2`, in `lua_next` order) and returns it.
extern "C-unwind" fn lua_http_form_urlencode(state: *mut ffi::lua_State) -> c_int {
    laux::lua_checktype(state, 1, ffi::LUA_TTABLE);
    laux::lua_pushnil(state);
    let mut result = String::new();
    while laux::lua_next(state, 1) {
        if !result.is_empty() {
            result.push('&');
        }

        // Convert a *copy* of the key: converting the traversal key itself in place (e.g. a
        // numeric key turned into a string) makes the next `lua_next` call fail.
        unsafe { ffi::lua_pushvalue(state, -2) };
        let key = laux::to_string_unchecked(state, -1);
        result.extend(form_urlencoded::byte_serialize(key.as_bytes()));
        laux::lua_pop(state, 1); // pop the key copy; `key` is not used past this point

        result.push('=');

        // The value is popped right after, so converting it in place is harmless.
        let value = laux::to_string_unchecked(state, -1);
        result.extend(form_urlencoded::byte_serialize(value.as_bytes()));
        laux::lua_pop(state, 1); // pop the value, keep the key for the next `lua_next`
    }
    laux::lua_push(state, result);
    1
}

/// Lua: `form_urldecode(query)` — parses an `application/x-www-form-urlencoded` string into
/// a table of decoded `key = value` pairs. A repeated key keeps its last value.
extern "C-unwind" fn lua_http_form_urldecode(state: *mut ffi::lua_State) -> c_int {
    let query_string = laux::lua_get::<&str>(state, 1);

    unsafe { ffi::lua_createtable(state, 0, 8) };

    // Push pairs straight from the parser; no intermediate Vec is needed.
    for (key, value) in form_urlencoded::parse(query_string.as_bytes()) {
        laux::lua_push(state, &*key);
        laux::lua_push(state, &*value);
        unsafe { ffi::lua_rawset(state, -3) };
    }
    1
}

/// Lua: `decode(ptr, len)` — unpack callback for the `http` protocol.
///
/// The message payload is the native-endian address of a `Box<reqwest::Response>` leaked by
/// `moon_send`. This reclaims the box, so each payload must be decoded exactly once. Returns
/// `{ version = "HTTP/1.1", status_code = 200, headers = { [lowercase-name] = value, ... } }`.
/// Raises a Lua error if the payload is not exactly one pointer wide.
///
/// NOTE(review): no `body` field is produced, yet the Lua wrapper's `tojson` reads
/// `response.body`. Reading the body requires an `.await`, so it presumably has to happen in
/// `http_request` before the response is sent — confirm.
extern "C-unwind" fn decode(state: *mut ffi::lua_State) -> c_int {
    let bytes = laux::lua_from_raw_parts(state, 1);
    let raw: [u8; std::mem::size_of::<isize>()] = match bytes.try_into() {
        Ok(raw) => raw,
        // Raise a Lua error instead of panicking: a Rust panic unwinding through the Lua VM's
        // C frames would bypass Lua's own error handling.
        Err(_) => {
            return unsafe { ffi::luaL_error(state, cstr!("decode: payload is not a pointer")) }
        }
    };
    // SAFETY: `raw` holds an address produced by `Box::into_raw` in `moon_send::<Response>`;
    // ownership is taken back here, exactly once per message.
    let response = unsafe { Box::from_raw(isize::from_ne_bytes(raw) as *mut Response) };

    unsafe {
        ffi::lua_createtable(state, 0, 6);
        lua_rawsetfield!(
            state,
            -1,
            "version",
            laux::lua_push(state, version_to_string(&response.version()))
        );
        lua_rawsetfield!(
            state,
            -1,
            "status_code",
            laux::lua_push(state, response.status().as_u16() as u32)
        );

        ffi::lua_pushstring(state, cstr!("headers"));
        ffi::lua_createtable(state, 0, 16);

        for (key, value) in response.headers().iter() {
            // `HeaderName::as_str` is always lowercase, so no extra lowercasing pass is needed.
            laux::lua_push(state, key.as_str());
            laux::lua_push(state, value.to_str().unwrap_or("").trim());
            ffi::lua_rawset(state, -3);
        }
        ffi::lua_rawset(state, -3);
    }
    1
}

/// Lua module entry point for `require "rust.httpc"`: returns a table with
/// `request`, `form_urlencode`, `form_urldecode` and `decode`.
///
/// # Safety
///
/// This function is unsafe because it dereferences a raw pointer `state`.
/// The caller must ensure that `state` is a valid pointer to a `lua_State`
/// and that it remains valid for the duration of the function call.
#[no_mangle]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub unsafe extern "C-unwind" fn luaopen_rust_httpc(state: *mut ffi::lua_State) -> c_int {
    let l = [
        lreg!("request", lua_http_request),
        lreg!("form_urlencode", lua_http_form_urlencode),
        lreg!("form_urldecode", lua_http_form_urldecode),
        lreg!("decode", decode),
        lreg_null!(),
    ];

    // SAFETY: `state` is valid per this function's contract, and `l` is terminated by the
    // `lreg_null!()` sentinel that `luaL_setfuncs` requires.
    unsafe {
        // Exclude the sentinel from the field-count hint, as `luaL_newlib` does.
        ffi::lua_createtable(state, 0, (l.len() - 1) as c_int);
        ffi::luaL_setfuncs(state, l.as_ptr(), 0);
    }

    1
}

这样就完成了Rust异步库与moon的集成。下面是Lua层的包装代码:

---@diagnostic disable: inject-field
local moon = require "moon"
local json = require "json"
local c = require "rust.httpc"

-- Protocol id under which HTTP responses come back from the Rust side.
local protocol_type = 21
-- moon's `send_message` pointer; passed along to `c.request` (the Rust side shown does not read it).
local callback = _G.send_message

-- Turns the raw message payload (a boxed Rust response) into a Lua table.
local function unpack_response(sz, len)
    return c.decode(sz, len)
end

-- Register the handler for incoming HTTP response messages.
moon.register_protocol {
    name = "http",
    PTYPE = protocol_type,
    pack = function(...)
        return ...
    end,
    unpack = unpack_response,
}

---Decodes the JSON body of a successful (HTTP 200) response.
---@param response table
---@return table # the decoded body, or an empty table for any non-200 status
local function tojson(response)
    if response.status_code == 200 then
        return json.decode(response.body)
    end
    return {}
end

---@class HttpRequestOptions
---@field headers? table<string,string>
---@field timeout? integer Request timeout in seconds. default 5s
---@field proxy? string

local client = {}

---Sends a GET request and waits for the response.
---Note: `opts` is modified in place (owner, session, url and method are set on it).
---@param url string
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.get(url, opts)
    local o = opts or {}
    o.owner = moon.id
    o.session = moon.next_sequence()
    o.url, o.method = url, "GET"
    return moon.wait(c.request(o, protocol_type, callback))
end

---Encodes `data` as JSON, POSTs it to `url` and waits for the response.
---A `Content-Type: application/json` header is added unless the caller set one.
---On status 200 the response body is replaced by the decoded JSON table.
---Note: `opts` (and `opts.headers`, when given) is modified in place.
---@param url string
---@param data table
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.post_json(url, data, opts)
    opts = opts or {}
    opts.owner = moon.id
    opts.session = moon.next_sequence()
    -- Use a fresh header table per call: assigning a shared module-level table here would let
    -- a caller that later edits `opts.headers` leak those headers into every future request.
    opts.headers = opts.headers or {}
    if not opts.headers['Content-Type'] then
        opts.headers['Content-Type'] = "application/json"
    end

    opts.url = url
    opts.method = "POST"
    opts.body = json.encode(data)

    local res = moon.wait(c.request(opts, protocol_type, callback))

    if res.status_code == 200 then
        res.body = tojson(res)
    end
    return res
end

---POSTs `data` verbatim as the request body and waits for the response.
---Note: `opts` is modified in place (owner, session, url, body and method are set on it).
---@param url string
---@param data string
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.post(url, data, opts)
    local o = opts or {}
    o.owner = moon.id
    o.session = moon.next_sequence()
    o.url, o.body, o.method = url, data, "POST"
    return moon.wait(c.request(o, protocol_type, callback))
end

---URL-encodes `data` as a form body, POSTs it to `url` and waits for the response.
---A `Content-Type: application/x-www-form-urlencoded` header is added unless the caller set one.
---Note: `opts` (and `opts.headers`, when given) is modified in place.
---@param url string
---@param data table<string,string>
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.post_form(url, data, opts)
    opts = opts or {}
    opts.owner = moon.id
    opts.session = moon.next_sequence()
    -- Use a fresh header table per call: assigning a shared module-level table here would let
    -- a caller that later edits `opts.headers` leak those headers into every future request.
    opts.headers = opts.headers or {}
    if not opts.headers['Content-Type'] then
        opts.headers['Content-Type'] = "application/x-www-form-urlencoded"
    end

    -- Stringify every value before handing the table to the Rust encoder.
    local fields = {}
    for k, v in pairs(data) do
        fields[k] = tostring(v)
    end

    opts.url = url
    opts.method = "POST"
    opts.body = c.form_urlencode(fields)

    return moon.wait(c.request(opts, protocol_type, callback))
end

return client