Writing Lua extension library using Rust

Bruce edited this page on Dec 24, 2024 · 2 revisions
Rust 拥有完善的包管理机制，借助 Rust 可以极大地丰富 Lua 扩展库，例如基于 tokio 的 https client、sqlx 等网络相关的库。使用 Rust 编写 Lua 扩展库也比较容易，因为 Rust 本身就支持编译出供 C/C++ 调用的动态库。
编写 Lua 扩展库需要依赖 lib-lua-sys，它参考了 mlua-sys（Low level (FFI) bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Roblox Luau）。Rust 的 mlua 库本身支持编写 Lua 扩展库，但它比较复杂，使用起来也不如 Lua C API 灵活，所以这里只使用它的 FFI binding 部分。同时 lib-lua-sys 也做了一些改动：mlua 默认静态链接 Lua 库，而由于要为 moon 编写扩展库（扩展库必须与宿主进程共用同一份 Lua 运行时），lib-lua-sys 改为动态链接 Lua 库。
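新建 Rust 项目，并在 Cargo.toml 中手动添加 lib-lua-sys 依赖（下面的示例还依赖了 lib-core）。另外需要把 crate 类型设置为 `cdylib`（`[lib] crate-type = ["cdylib"]`），这样才能编译出可以被 Lua `require` 加载的动态库。Rust 包管理的具体机制请参考 Rust 相关文档，这里不再详细描述。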
[dependencies]
lib-core = { path = "../../libs/lib-core"}
lib-lua = {package = "lib-lua-sys", path = "../../libs/lib-lua-sys",features = ["lua54"]}
然后就可以用类似 Lua C API 的方式编写 Lua 扩展库了。下面以 lua_excel 为例：
use calamine::{open_workbook, Data, Reader, Xlsx};
use csv::ReaderBuilder;
use lib_lua::{self, cstr, ffi, ffi::luaL_Reg, laux, lreg, lreg_null};
use std::{os::raw::c_int, path::Path};
/// Reads a CSV file into Lua tables and pushes the result onto the Lua stack.
///
/// On success one value is returned to Lua: a sequence holding a single sheet table
/// `{ sheet_name = <file stem>, data = { {field, ...}, ... } }`. At most `max_row`
/// records are read and every field is pushed as a Lua string (the file has no header row).
///
/// On failure (the file cannot be opened, or a record cannot be parsed) two values are
/// returned: `false` and an error message.
fn read_csv(state: *mut ffi::lua_State, path: &Path, max_row: usize) -> c_int {
    let res = ReaderBuilder::new().has_headers(false).from_path(path);
    // Result table: a sequence with one sheet at index 1. It is pushed before the match, so on
    // the error paths it stays below the `false, message` pair that Lua actually receives.
    unsafe {
        ffi::lua_createtable(state, 1, 0);
    }
    let mut reader = match res {
        Ok(reader) => reader,
        Err(err) => {
            unsafe {
                ffi::lua_pushboolean(state, 0);
            }
            laux::lua_push(
                state,
                format!("open file '{}' error: {}", path.to_string_lossy(), err).as_str(),
            );
            return 2;
        }
    };
    unsafe {
        // Sheet table: { sheet_name = <file stem>, data = rows }.
        ffi::lua_createtable(state, 0, 2);
        laux::lua_push(
            state,
            path.file_stem()
                .unwrap_or_default()
                .to_str()
                .unwrap_or_default(),
        );
        ffi::lua_setfield(state, -2, cstr!("sheet_name"));
        // Rows table. The array-part size is only a hint; never reserve more than `max_row`.
        ffi::lua_createtable(state, max_row.min(1024) as c_int, 0);
    }
    let mut idx: usize = 0;
    for result in reader.records() {
        if idx >= max_row {
            break;
        }
        match result {
            Ok(record) => unsafe {
                // Fields are stored at integer keys 1..=n, so size the array part (hint only;
                // an out-of-range length just falls back to 0).
                ffi::lua_createtable(state, c_int::try_from(record.len()).unwrap_or(0), 0);
                for (i, field) in record.iter().enumerate() {
                    laux::lua_push(state, field);
                    ffi::lua_rawseti(state, -2, (i + 1) as ffi::lua_Integer);
                }
                idx += 1;
                ffi::lua_rawseti(state, -2, idx as ffi::lua_Integer);
            },
            Err(err) => unsafe {
                ffi::lua_pushboolean(state, 0);
                laux::lua_push(
                    state,
                    format!("read csv '{}' error: {}", path.to_string_lossy(), err).as_str(),
                );
                return 2;
            },
        }
    }
    unsafe {
        ffi::lua_setfield(state, -2, cstr!("data"));
        ffi::lua_rawseti(state, -2, 1);
    }
    1
}
fn read_xlxs(state: *mut ffi::lua_State, path: &Path, max_row: usize) -> c_int {
let res: Result<Xlsx<_>, _> = open_workbook(path);
match res {
Ok(mut workbook) => {
unsafe {
ffi::lua_createtable(state, 0, 0);
}
let mut sheet_counter = 0;
workbook.sheet_names().iter().for_each(|sheet| {
if let Ok(range) = workbook.worksheet_range(sheet) {
unsafe {
ffi::lua_createtable(state, 0, 2);
laux::lua_push(state, sheet.as_str());
ffi::lua_setfield(state, -2, cstr!("sheet_name"));
ffi::lua_createtable(state, range.rows().len() as i32, 0);
for (i, row) in range.rows().enumerate() {
if i >= max_row {
break;
}
//rows
ffi::lua_createtable(state, row.len() as i32, 0);
for (j, cell) in row.iter().enumerate() {
//columns
match cell {
Data::Int(v) => {
ffi::lua_pushinteger(state, *v as ffi::lua_Integer)
}
Data::Float(v) => ffi::lua_pushnumber(state, *v),
Data::String(v) => laux::lua_push(state, v.as_str()),
Data::Bool(v) => ffi::lua_pushboolean(state, *v as i32),
Data::Error(v) => laux::lua_push(state, v.to_string()),
Data::Empty => ffi::lua_pushnil(state),
Data::DateTime(v) => laux::lua_push(state, v.to_string()),
_ => ffi::lua_pushnil(state),
}
ffi::lua_rawseti(state, -2, (j + 1) as i64);
}
ffi::lua_rawseti(state, -2, (i + 1) as i64);
}
ffi::lua_setfield(state, -2, cstr!("data"));
}
sheet_counter += 1;
unsafe {
ffi::lua_rawseti(state, -2, sheet_counter as i64);
}
}
});
1
}
Err(err) => unsafe {
ffi::lua_pushboolean(state, 0);
laux::lua_push(state, format!("{}", err).as_str());
2
},
}
}
/// Lua: `read(filename [, max_row])`.
///
/// Dispatches on the file extension. `.csv` goes to `read_csv` and `.xlsx` goes to
/// `read_xlxs`; extensions are compared case-insensitively. `max_row` defaults to
/// "no limit".
///
/// Returns whatever the reader returns. For a missing or unsupported extension it
/// returns `false, message`.
extern "C-unwind" fn lua_excel_read(state: *mut ffi::lua_State) -> c_int {
    let filename: &str = laux::lua_get(state, 1);
    let max_row: usize = laux::lua_opt(state, 2).unwrap_or(usize::MAX);
    let path = Path::new(filename);
    let Some(ext) = path.extension() else {
        unsafe {
            ffi::lua_pushboolean(state, 0);
        }
        laux::lua_push(
            state,
            format!("unsupport file type: {}", path.to_string_lossy()),
        );
        return 2;
    };
    let ext = ext.to_string_lossy();
    // Compare in lower case so "DATA.CSV" or "Book.XLSX" are accepted as well; the error
    // message still reports the extension exactly as given.
    match ext.to_ascii_lowercase().as_str() {
        "csv" => read_csv(state, path, max_row),
        "xlsx" => read_xlxs(state, path, max_row),
        _ => {
            unsafe {
                ffi::lua_pushboolean(state, 0);
            }
            laux::lua_push(state, format!("unsupport file type: {}", ext));
            2
        }
    }
}
/// Module loader for `require "rust.excel"`; leaves a table `{ read = ... }` on the stack.
///
/// # Safety
///
/// `state` must point to a valid, live `lua_State` for the whole call, as it is passed
/// straight to the Lua C API.
#[no_mangle]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub unsafe extern "C-unwind" fn luaopen_rust_excel(state: *mut ffi::lua_State) -> c_int {
    // Registration list, terminated by the null sentinel entry that luaL_setfuncs expects.
    let functions = [lreg!("read", lua_excel_read), lreg_null!()];
    let size_hint = functions.len() as c_int;
    ffi::lua_createtable(state, 0, size_hint);
    ffi::luaL_setfuncs(state, functions.as_ptr(), 0);
    1
}
对于带异步调用的库，一般需要与框架的事件循环相结合。由于 moon 基于 Actor 模型，一切皆消息，只需把发送消息的函数导出给 Rust，就可以接入 moon 的事件循环系统。C++ 侧的导出函数如下：
// Exported function: lets Rust extension libraries post a message into moon's event loop.
// If the server is already gone (the weak reference no longer locks), the call does nothing.
extern "C" {
void MOON_EXPORT
send_message(uint8_t type, uint32_t receiver, int64_t session, const char* data, size_t len) {
    // wk_server is presumably a std::weak_ptr to the server; lock() yields null once it is destroyed.
    if (auto server = wk_server.lock()) {
        moon::message msg(len);
        msg.set_type(type);
        msg.set_receiver(receiver);
        msg.set_sessionid(session);
        // Presumably copies the bytes, so callers may pass stack buffers.
        msg.write_data(std::string_view(data, len));
        server->send_message(std::move(msg));
    }
}
}
在 Rust 中声明并使用导出的函数：
unsafe extern "C-unwind" {
unsafe fn send_message(type_: u8, receiver: u32, session: i64, data: *const i8, len: usize);
}
/// Hands ownership of `res` to moon actor `owner` as a message of type `protocol_type`.
///
/// `res` is moved into a `Box`, and only the box's address is copied into the message payload,
/// as the native-endian bytes of an `isize`. The receiver is expected to rebuild the `Box<T>`
/// from those bytes exactly once, which also frees it.
///
/// When `session` is 0, nothing is sent and `res` is simply dropped.
///
/// NOTE(review): if moon discards the message (the C++ side returns early when the server
/// is gone), the box is leaked.
pub fn moon_send<T>(
    protocol_type: u8,
    owner: u32,
    session: i64,
    res: T,
) {
    if session != 0 {
        let address = Box::into_raw(Box::new(res)) as isize;
        let payload = address.to_ne_bytes();
        unsafe {
            send_message(
                protocol_type,
                owner,
                session,
                payload.as_ptr() as *const i8,
                payload.len(),
            );
        }
    }
}
/// Sends the UTF-8 bytes of `data` to moon actor `owner` as a message of type `protocol_type`.
///
/// Unlike `moon_send`, this sends even when `session` is 0. Only a pointer and length are
/// passed; `data` stays alive until the call returns.
pub fn moon_send_string(
    protocol_type: u8,
    owner: u32,
    session: i64,
    data: String,
) {
    let bytes = data.as_bytes();
    unsafe {
        send_message(
            protocol_type,
            owner,
            session,
            bytes.as_ptr().cast::<i8>(),
            bytes.len(),
        );
    }
}
/// moon message type for error replies (`lua_http_request` reports failed requests with it).
pub const PTYPE_ERROR: u8 = 4;
/// moon message type for log records (sent by `moon_log`).
pub const PTYPE_LOG: u8 = 13;
// Log levels; `moon_log` sends the chosen level as a decimal prefix of the log text.
// NOTE(review): these ids (and the PTYPE values above) presumably must match moon's own
// C++ constants — keep them in sync.
pub const LOG_LEVEL_ERROR: u8 = 1;
pub const LOG_LEVEL_WARN: u8 = 2;
pub const LOG_LEVEL_INFO: u8 = 3;
pub const LOG_LEVEL_DEBUG: u8 = 4;
/// Sends one log line to moon on behalf of actor `owner`.
///
/// The payload is `log_level` written in decimal, immediately followed by `data`. For
/// example, `LOG_LEVEL_INFO` with "hello" gives "3hello". It is sent as a `PTYPE_LOG`
/// message with session 0.
pub fn moon_log(owner: u32, log_level: u8, data: String) {
    let mut line = log_level.to_string();
    line.push_str(&data);
    unsafe {
        send_message(PTYPE_LOG, owner, 0, line.as_ptr() as *const i8, line.len());
    }
}
注意：带异步运行时的 Rust 扩展库不能随 Lua 虚拟机关闭而卸载——异步运行时的后台线程可能仍在执行该库中的代码，一旦库被卸载就可能导致崩溃。因此需要修改 Lua 源码，取消对 `dlclose(lib)` / `FreeLibrary((HMODULE)lib)` 的调用。
下面以 lua_http 库为例：
use lib_core::context::CONTEXT;
use lib_lua::{
self, cstr,
ffi::{self, luaL_Reg},
laux, lreg, lreg_null, lua_rawsetfield,
};
use reqwest::{header::HeaderMap, Method, Response};
use std::{error::Error, ffi::c_int, str::FromStr};
use url::form_urlencoded::{self};
use crate::{moon_send, moon_send_string, PTYPE_ERROR};
/// Everything `http_request` needs for one request.
///
/// `lua_http_request` collects these values from the Lua `opts` table so they can be moved
/// into the spawned async task.
struct HttpRequest {
    /// moon actor id that receives the response (the Lua wrapper sets it to `moon.id`).
    owner: u32,
    /// Session id the response is tagged with; with 0, `moon_send` drops the response.
    session: i64,
    /// HTTP method name, parsed with `Method::from_str` (defaults to "GET").
    method: String,
    /// Request URL.
    url: String,
    /// Request body (empty when not given).
    body: String,
    /// Request headers parsed from `opts.headers`.
    headers: HeaderMap,
    /// Timeout passed to `get_http_client`; the Lua wrapper documents it in seconds (default 5).
    timeout: u64,
    /// Proxy passed to `get_http_client`; empty string when not given.
    /// NOTE(review): how an empty proxy is treated is up to `get_http_client` — confirm.
    proxy: String,
}
/// Returns the conventional label for an HTTP version ("HTTP/1.1", "HTTP/2.0", ...).
/// Versions not listed here map to "Unknown".
///
/// Every label is a string literal, so the result is `'static` rather than borrowing from
/// `version`. Callers may keep it after the `Version` (or the response it came from) is gone.
fn version_to_string(version: &reqwest::Version) -> &'static str {
    match *version {
        reqwest::Version::HTTP_09 => "HTTP/0.9",
        reqwest::Version::HTTP_10 => "HTTP/1.0",
        reqwest::Version::HTTP_11 => "HTTP/1.1",
        reqwest::Version::HTTP_2 => "HTTP/2.0",
        reqwest::Version::HTTP_3 => "HTTP/3.0",
        _ => "Unknown",
    }
}
async fn http_request(
req: HttpRequest,
protocol_type: u8,
) -> Result<(), Box<dyn Error>> {
let http_client = &CONTEXT.get_http_client(req.timeout, &req.proxy);
let response = http_client
.request(Method::from_str(req.method.as_str())?, req.url)
.headers(req.headers)
.body(req.body)
.send()
.await?;
moon_send(protocol_type, req.owner, req.session, response);
Ok(())
}
/// Reads `t.headers` from the table at stack slot `index` into a `HeaderMap`.
///
/// A missing or non-table `headers` field yields an empty map. Non-string keys and values
/// are read as "". Returns the parser's message if a name or value is not a valid HTTP
/// header.
///
/// The Lua stack is left exactly as it was on every path. `index` should be an absolute
/// (positive) index, because a key is pushed before it is used.
fn extract_headers(state: *mut ffi::lua_State, index: i32) -> Result<HeaderMap, String> {
    let mut headers = HeaderMap::new();
    laux::push_c_string(state, cstr!("headers"));
    // lua_rawget replaces the key with `t.headers`, so one value is now on top of the stack.
    if laux::lua_rawget(state, index) != ffi::LUA_TTABLE {
        laux::lua_pop(state, 1); // drop the non-table value (usually nil)
        return Ok(headers);
    }
    laux::lua_pushnil(state); // first key for lua_next
    while laux::lua_next(state, -2) {
        // stack: ..., headers, key, value
        // NOTE(review): if lua_opt converts non-string keys in place (lua_tolstring), a numeric
        // key would confuse lua_next — confirm lua_opt only reads string keys.
        let key: &str = laux::lua_opt(state, -2).unwrap_or_default();
        let value: &str = laux::lua_opt(state, -1).unwrap_or_default();
        let parsed = key
            .parse::<reqwest::header::HeaderName>()
            .map_err(|err| err.to_string())
            .and_then(|name| {
                value
                    .parse::<reqwest::header::HeaderValue>()
                    .map(|value| (name, value))
                    .map_err(|err| err.to_string())
            });
        match parsed {
            Ok((name, value)) => {
                headers.insert(name, value);
                laux::lua_pop(state, 1); // pop value, keep key for the next lua_next
            }
            Err(err) => {
                laux::lua_pop(state, 3); // pop value, key and the headers table
                return Err(err);
            }
        }
    }
    laux::lua_pop(state, 1); // pop the headers table
    Ok(headers)
}
/// Lua: `request(opts, protocol_type)`. Starts an asynchronous HTTP request on the tokio
/// runtime.
///
/// `opts` (argument 1) must be a table. The fields read are:
/// - `session` and `owner`;
/// - `method` (default "GET"), `url` and `body`;
/// - `headers`;
/// - `timeout` (default 5) and `proxy`.
///
/// `protocol_type` (argument 2) is the moon message type the response is delivered with.
///
/// Returns `session` immediately. The response later reaches `owner` through `moon_send`
/// (dropped when `session` is 0). A request failure is sent as a `PTYPE_ERROR` message
/// carrying the error text.
///
/// Returns `false, message` when `opts.headers` is invalid or no tokio runtime exists.
extern "C-unwind" fn lua_http_request(state: *mut ffi::lua_State) -> c_int {
    laux::lua_checktype(state, 1, ffi::LUA_TTABLE);
    let protocol_type = laux::lua_get::<u8>(state, 2);

    let headers = match extract_headers(state, 1) {
        Ok(parsed) => parsed,
        Err(message) => {
            laux::lua_push(state, false);
            laux::lua_push(state, message);
            return 2;
        }
    };

    let session: i64 = laux::opt_field(state, 1, "session").unwrap_or(0);
    let request = HttpRequest {
        owner: laux::opt_field(state, 1, "owner").unwrap_or_default(),
        session,
        method: laux::opt_field(state, 1, "method").unwrap_or_else(|| "GET".to_string()),
        url: laux::opt_field(state, 1, "url").unwrap_or_default(),
        body: laux::opt_field(state, 1, "body").unwrap_or_default(),
        headers,
        timeout: laux::opt_field(state, 1, "timeout").unwrap_or(5),
        proxy: laux::opt_field(state, 1, "proxy").unwrap_or_default(),
    };

    let runtime_available = match CONTEXT.get_tokio_runtime().as_ref() {
        Some(runtime) => {
            runtime.spawn(async move {
                let (owner, session) = (request.owner, request.session);
                if let Err(err) = http_request(request, protocol_type).await {
                    moon_send_string(PTYPE_ERROR, owner, session, err.to_string());
                }
            });
            true
        }
        None => false,
    };

    if !runtime_available {
        laux::lua_push(state, false);
        laux::lua_push(state, "No tokio runtime");
        return 2;
    }
    laux::lua_push(state, session);
    1
}
/// Lua: `form_urlencode(t)`. Encodes the key/value pairs of table `t` as
/// `application/x-www-form-urlencoded` text ("k1=v1&k2=v2").
///
/// Pair order follows Lua's `next` traversal, so it is unspecified.
extern "C-unwind" fn lua_http_form_urlencode(state: *mut ffi::lua_State) -> c_int {
    laux::lua_checktype(state, 1, ffi::LUA_TTABLE);
    let mut result = String::new();
    laux::lua_pushnil(state); // first key for lua_next
    while laux::lua_next(state, 1) {
        // stack: ..., key, value
        // Convert a *copy* of the key. Per the Lua manual, lua_tolstring-style conversion of a
        // non-string key in place would change the key and confuse the next lua_next call.
        unsafe { ffi::lua_pushvalue(state, -2) };
        // stack: ..., key, value, key_copy
        let key = laux::to_string_unchecked(state, -1);
        let value = laux::to_string_unchecked(state, -2);
        if !result.is_empty() {
            result.push('&');
        }
        result.extend(form_urlencoded::byte_serialize(key.as_bytes()));
        result.push('=');
        result.extend(form_urlencoded::byte_serialize(value.as_bytes()));
        // Pop value and key copy; the original key stays for the next lua_next.
        laux::lua_pop(state, 2);
    }
    laux::lua_push(state, result);
    1
}
/// Lua: `form_urldecode(query)`. Parses `application/x-www-form-urlencoded` text into a
/// table mapping names to values.
///
/// If a name appears more than once, its last value wins.
extern "C-unwind" fn lua_http_form_urldecode(state: *mut ffi::lua_State) -> c_int {
    let query_string = laux::lua_get::<&str>(state, 1);
    unsafe { ffi::lua_createtable(state, 0, 8) };
    // Stream the decoded pairs straight into the table instead of buffering them first.
    for (name, value) in form_urlencoded::parse(query_string.as_bytes()) {
        laux::lua_push(state, name.into_owned());
        laux::lua_push(state, value.into_owned());
        unsafe { ffi::lua_rawset(state, -3) };
    }
    1
}
/// `unpack` handler for the http protocol. It rebuilds the boxed `reqwest::Response` sent by
/// `moon_send` and pushes a Lua table describing it.
///
/// The table has these fields:
/// - `version`, e.g. "HTTP/1.1";
/// - `status_code`, an integer;
/// - `headers`, mapping each lower-case header name to its trimmed value.
///
/// For a repeated header, only the last value is kept. A value that `HeaderValue::to_str`
/// rejects becomes "".
///
/// # Panics
/// Panics if the payload is not exactly `size_of::<isize>()` bytes long.
///
/// NOTE(review): no `body` field is produced, and the `Response` is dropped here without its
/// body being read. The Lua wrapper's `tojson` reads `response.body`, so confirm the intended
/// design; reading the body would need an async `.bytes().await` before `moon_send`.
extern "C-unwind" fn decode(state: *mut ffi::lua_State) -> c_int {
    let bytes = laux::lua_from_raw_parts(state, 1);
    let p_as_isize = isize::from_ne_bytes(bytes.try_into().expect("slice with incorrect length"));
    // SAFETY: the payload is assumed to be the address written by `moon_send` for a boxed
    // `Response`. Each such message must be decoded exactly once; otherwise this is a
    // use-after-free / double free.
    let response = unsafe { Box::from_raw(p_as_isize as *mut Response) };
    unsafe {
        // Three fields are set below: version, status_code, headers.
        ffi::lua_createtable(state, 0, 3);
        lua_rawsetfield!(
            state,
            -1,
            "version",
            laux::lua_push(state, version_to_string(&response.version()))
        );
        lua_rawsetfield!(
            state,
            -1,
            "status_code",
            laux::lua_push(state, response.status().as_u16() as u32)
        );
        ffi::lua_pushstring(state, cstr!("headers"));
        ffi::lua_createtable(
            state,
            0,
            c_int::try_from(response.headers().keys_len()).unwrap_or(0),
        );
        for (key, value) in response.headers().iter() {
            // HeaderName::as_str() is always lower case, so no extra allocation is needed.
            laux::lua_push(state, key.as_str());
            laux::lua_push(state, value.to_str().unwrap_or("").trim());
            ffi::lua_rawset(state, -3);
        }
        ffi::lua_rawset(state, -3);
    }
    1
}
/// Module loader for `require "rust.httpc"`.
///
/// Leaves a table with `request`, `form_urlencode`, `form_urldecode` and `decode` on the
/// stack.
///
/// # Safety
///
/// `state` must point to a valid, live `lua_State` for the whole call, as it is passed
/// straight to the Lua C API.
#[no_mangle]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub unsafe extern "C-unwind" fn luaopen_rust_httpc(state: *mut ffi::lua_State) -> c_int {
    // Registration list, terminated by the null sentinel entry that luaL_setfuncs expects.
    let functions = [
        lreg!("request", lua_http_request),
        lreg!("form_urlencode", lua_http_form_urlencode),
        lreg!("form_urldecode", lua_http_form_urldecode),
        lreg!("decode", decode),
        lreg_null!(),
    ];
    let size_hint = functions.len() as c_int;
    ffi::lua_createtable(state, 0, size_hint);
    ffi::luaL_setfuncs(state, functions.as_ptr(), 0);
    1
}
这样就完成了 Rust 异步库与 moon 的集成。Lua 层的包装代码如下：
---@diagnostic disable: inject-field
local moon = require "moon"
local json = require "json"
local c = require "rust.httpc"

local protocol_type = 21
local callback = _G['send_message'] -- moon's send_message function pointer, passed through to c.request

-- Register the handler for "http" messages: the payload is a boxed Rust object,
-- which c.decode converts into a Lua table.
moon.register_protocol {
    name = "http",
    PTYPE = protocol_type,
    pack = function(...) return ... end,
    unpack = function (sz, len)
        return c.decode(sz, len) -- convert the Rust object into a Lua table
    end
}

---@return table
local function tojson(response)
    if response.status_code ~= 200 then return {} end
    return json.decode(response.body)
end

--- Makes sure opts.headers carries a Content-Type.
--- A fresh table is created when the caller passed no headers, so no module-level table is
--- ever shared between requests (a later mutation of opts.headers would otherwise leak into
--- every subsequent request).
---@param opts table
---@param content_type string
local function ensure_content_type(opts, content_type)
    if not opts.headers then
        opts.headers = { ["Content-Type"] = content_type }
    elseif not opts.headers['Content-Type'] then
        opts.headers['Content-Type'] = content_type
    end
end

--- Fills in the routing fields, starts the request and waits for the response.
--- NOTE(review): when c.request fails it returns `false, errmsg` instead of a session;
--- confirm moon.wait handles that.
---@param opts table
---@param url string
---@param method string
local function send(opts, url, method)
    opts.owner = moon.id
    opts.session = moon.next_sequence()
    opts.url = url
    opts.method = method
    return moon.wait(c.request(opts, protocol_type, callback))
end

---@class HttpRequestOptions
---@field headers? table<string,string>
---@field timeout? integer Request timeout in seconds. default 5s
---@field proxy? string

local client = {}

---@param url string
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.get(url, opts)
    return send(opts or {}, url, "GET")
end

---@param url string
---@param data table
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.post_json(url, data, opts)
    opts = opts or {}
    ensure_content_type(opts, "application/json")
    opts.body = json.encode(data)
    local res = send(opts, url, "POST")
    if res.status_code == 200 then
        res.body = tojson(res)
    end
    return res
end

---@param url string
---@param data string
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.post(url, data, opts)
    opts = opts or {}
    opts.body = data
    return send(opts, url, "POST")
end

---@param url string
---@param data table<string,string>
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.post_form(url, data, opts)
    opts = opts or {}
    ensure_content_type(opts, "application/x-www-form-urlencoded")
    -- Hand the encoder plain strings only (keys too), so the Rust side never has to
    -- convert a non-string key while iterating the table.
    local fields = {}
    for k, v in pairs(data) do
        fields[tostring(k)] = tostring(v)
    end
    opts.body = c.form_urlencode(fields)
    return send(opts, url, "POST")
end

return client