Skip to content

Commit

Permalink
refactor: migrate to mlua
Browse files Browse the repository at this point in the history
  • Loading branch information
oddlama committed Sep 8, 2024
1 parent 393f268 commit df50719
Showing 1 changed file with 153 additions and 165 deletions.
318 changes: 153 additions & 165 deletions src/script/lua.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::path::Path;
use std::result::Result::{Err as StdErr, Ok as StdOk};

use anyhow::{Context, Ok, Result};
use rlua::{self, Error as LuaError, Lua};
use mlua::{self, Error as LuaError, ExternalResult, Lua, LuaOptions, StdLib};

pub struct LuaScript {
lua: Lua,
Expand All @@ -18,192 +18,180 @@ pub struct LuaScript {

impl LuaScript {
pub fn new(file: impl AsRef<Path>) -> Result<LuaScript> {
Ok(LuaScript::from_raw(
LuaScript::from_raw(
file.as_ref().display().to_string(),
fs::read_to_string(&file).context(format!("Could not read lua script {}", file.as_ref().display()))?,
))
)
}

pub fn from_raw(filename: String, code: String) -> LuaScript {
LuaScript {
lua: unsafe { Lua::new_with_debug() },
pub fn from_raw(filename: String, code: String) -> Result<LuaScript> {
Ok(LuaScript {
lua: Lua::new_with(StdLib::DEBUG, LuaOptions::default())?,
filename,
code,
}
})
}
}

impl Script for LuaScript {
fn apply(&self, bridge: &Bridge) -> Result<()> {
self.lua.context(|lua_ctx| {
lua_ctx.scope(|scope| {
let globals = lua_ctx.globals();
let symbol_set_auto = scope.create_function(
|_, (name, value, file, line, traceback): (String, String, String, u32, String)| {
bridge
.symbol(&name)
.unwrap()
.set_value_tracked(SymbolValue::Auto(value), file, line, Some(traceback))
.ok();
StdOk(())
},
)?;
let symbol_set_bool = scope.create_function(
|_, (name, value, file, line, traceback): (String, bool, String, u32, String)| {
bridge
.symbol(&name)
.unwrap()
.set_value_tracked(SymbolValue::Boolean(value), file, line, Some(traceback))
.ok();
StdOk(())
},
)?;
let symbol_set_number = scope.create_function(
|_, (name, value, file, line, traceback): (String, i64, String, u32, String)| {
// We use an i64 here to detect whether values in lua got clipped. Apparently
// when values wrap
if value < 0 {
return StdErr(LuaError::RuntimeError(
"Please pass values >=2*63 in string syntax. lua doesn't support this.".to_string(),
));
}
bridge
.symbol(&name)
.unwrap()
.set_value_tracked(SymbolValue::Number(value as u64), file, line, Some(traceback))
.ok();
StdOk(())
},
)?;
let symbol_set_tristate = scope.create_function(
|_, (name, value, file, line, traceback): (String, String, String, u32, String)| {
self.lua.scope(|scope| {
let symbol_set_auto = scope.create_function(
|_, (name, value, file, line, traceback): (String, String, String, u32, String)| {
bridge
.symbol(&name)
.unwrap()
.set_value_tracked(SymbolValue::Auto(value), file, line, Some(traceback))
.ok();
StdOk(())
},
)?;
let symbol_set_bool = scope.create_function(
|_, (name, value, file, line, traceback): (String, bool, String, u32, String)| {
bridge
.symbol(&name)
.unwrap()
.set_value_tracked(SymbolValue::Boolean(value), file, line, Some(traceback))
.ok();
StdOk(())
},
)?;
let symbol_set_number = scope.create_function(
|_, (name, value, file, line, traceback): (String, i64, String, u32, String)| {
// We use an i64 here to detect whether values in lua got clipped. Apparently
// when values wrap
if value < 0 {
return StdErr(LuaError::RuntimeError(
"Please pass values >=2*63 in string syntax. lua doesn't support this.".to_string(),
));
}
bridge
.symbol(&name)
.unwrap()
.set_value_tracked(SymbolValue::Number(value as u64), file, line, Some(traceback))
.ok();
StdOk(())
},
)?;
let symbol_set_tristate = scope.create_function(
|_, (name, value, file, line, traceback): (String, String, String, u32, String)| {
bridge
.symbol(&name)
.unwrap()
.set_value_tracked(
SymbolValue::Tristate(value.parse().map_err(|_| {
LuaError::RuntimeError(format!("Could not convert {value} to tristate"))
})?),
file,
line,
Some(traceback),
)
.ok();
StdOk(())
},
)?;
let symbol_satisfy_and_set = scope.create_function(
|_, (name, value, recursive, file, line, traceback): (String, String, bool, String, u32, String)| {
let value = value
.parse()
.map_err(|_| LuaError::RuntimeError(format!("Could not convert {value} to tristate")))?;
let satisfying_configuration = bridge.symbol(&name).unwrap().satisfy_track_error(
SymbolValue::Tristate(value),
file.clone(),
line,
Some(traceback.clone()),
SolverConfig {
recursive,
desired_value: value,
..SolverConfig::default()
},
);

// If there was an error, it will have been tracked already.
// Ignore and continue.
if satisfying_configuration.is_err() {
return StdOk(());
}

for (sym, value) in satisfying_configuration.unwrap() {
bridge
.symbol(&name)
.symbol(&sym)
.unwrap()
.set_value_tracked(
SymbolValue::Tristate(value.parse().map_err(|_| {
LuaError::RuntimeError(format!("Could not convert {value} to tristate"))
})?),
file,
line,
Some(traceback),
)
.ok();
StdOk(())
},
)?;
let symbol_satisfy_and_set =
scope.create_function(
|_,
(name, value, recursive, file, line, traceback): (
String,
String,
bool,
String,
u32,
String,
)| {
let value = value.parse().map_err(|_| {
LuaError::RuntimeError(format!("Could not convert {value} to tristate"))
})?;
let satisfying_configuration = bridge.symbol(&name).unwrap().satisfy_track_error(
SymbolValue::Tristate(value),
file.clone(),
line,
Some(traceback.clone()),
SolverConfig {
recursive,
desired_value: value,
..SolverConfig::default()
},
);

// If there was an error, it will have been tracked already.
// Ignore and continue.
if satisfying_configuration.is_err() {
return StdOk(());
}

for (sym, value) in satisfying_configuration.unwrap() {
bridge
.symbol(&sym)
.unwrap()
.set_value_tracked(
SymbolValue::Tristate(value),
file.clone(),
line,
Some(traceback.clone()),
)
.ok();
}

let mut symbol = bridge.symbol(&name).unwrap();
if symbol.prompt_count() > 0 {
symbol
.set_value_tracked(SymbolValue::Tristate(value), file, line, Some(traceback))
.ok();
}

StdOk(())
},
)?;
let symbol_get_string =
scope.create_function(|_, name: String| StdOk(bridge.symbol(&name).unwrap().get_string_value()))?;
let symbol_get_type = scope.create_function(|_, name: String| {
StdOk(format!("{:?}", bridge.symbol(&name).unwrap().symbol_type()))
})?;

let load_kconfig = scope.create_function(|_, (path, checked): (String, bool)| {
if checked {
KConfig::new(path)
.map_err(|e| LuaError::RuntimeError(e.to_string()))?
.apply(bridge)
)
.ok();
}

let mut symbol = bridge.symbol(&name).unwrap();
if symbol.prompt_count() > 0 {
symbol
.set_value_tracked(SymbolValue::Tristate(value), file, line, Some(traceback))
.ok();
// Errors will be tracked automatically
StdOk(())
} else {
bridge
.read_config_unchecked(path)
.map_err(|e| LuaError::RuntimeError(e.to_string()))
}
})?;

let kernel_env = scope.create_function(|_, name: String| StdOk(bridge.get_env(&name)))?;

let ak = lua_ctx.create_table()?;
ak.set("kernel_dir", bridge.kernel_dir.to_str())?;
ak.set("kernel_version_str", bridge.get_env("KERNELVERSION"))?;
ak.set("symbol_set_auto", symbol_set_auto)?;
ak.set("symbol_set_bool", symbol_set_bool)?;
ak.set("symbol_set_number", symbol_set_number)?;
ak.set("symbol_set_tristate", symbol_set_tristate)?;
ak.set("symbol_satisfy_and_set", symbol_satisfy_and_set)?;
ak.set("symbol_get_string", symbol_get_string)?;
ak.set("symbol_get_type", symbol_get_type)?;
ak.set("load_kconfig", load_kconfig)?;
ak.set("kernel_env", kernel_env)?;
globals.set("ak", ak)?;

lua_ctx.load(include_bytes!("api.lua")).set_name("api.lua")?.exec()?;

let mut define_all_syms = String::new();
for name in bridge.name_to_symbol.keys() {
let has_uppercase_char = name.chars().any(|c| c.is_ascii_uppercase());
if !name.is_empty() && has_uppercase_char {
writeln!(define_all_syms, "CONFIG_{name} = Symbol:new(nil, \"{name}\")")?;
if !name.chars().next().unwrap().is_ascii_digit() {
writeln!(define_all_syms, "{name} = CONFIG_{name}")?;
}

StdOk(())
},
)?;
let symbol_get_string =
scope.create_function(|_, name: String| StdOk(bridge.symbol(&name).unwrap().get_string_value()))?;
let symbol_get_type = scope.create_function(|_, name: String| {
StdOk(format!("{:?}", bridge.symbol(&name).unwrap().symbol_type()))
})?;

let load_kconfig = scope.create_function(|_, (path, checked): (String, bool)| {
if checked {
KConfig::new(path)
.map_err(|e| LuaError::RuntimeError(e.to_string()))?
.apply(bridge)
.ok();
// Errors will be tracked automatically
StdOk(())
} else {
bridge
.read_config_unchecked(path)
.map_err(|e| LuaError::RuntimeError(e.to_string()))
}
})?;

let kernel_env = scope.create_function(|_, name: String| StdOk(bridge.get_env(&name)))?;

let ak = self.lua.create_table()?;
ak.set("kernel_dir", bridge.kernel_dir.to_str())?;
ak.set("kernel_version_str", bridge.get_env("KERNELVERSION"))?;
ak.set("symbol_set_auto", symbol_set_auto)?;
ak.set("symbol_set_bool", symbol_set_bool)?;
ak.set("symbol_set_number", symbol_set_number)?;
ak.set("symbol_set_tristate", symbol_set_tristate)?;
ak.set("symbol_satisfy_and_set", symbol_satisfy_and_set)?;
ak.set("symbol_get_string", symbol_get_string)?;
ak.set("symbol_get_type", symbol_get_type)?;
ak.set("load_kconfig", load_kconfig)?;
ak.set("kernel_env", kernel_env)?;
self.lua.globals().set("ak", ak)?;

self.lua.load(include_str!("api.lua")).set_name("api.lua").exec()?;

let mut define_all_syms = String::new();
for name in bridge.name_to_symbol.keys() {
let has_uppercase_char = name.chars().any(|c| c.is_ascii_uppercase());
if !name.is_empty() && has_uppercase_char {
writeln!(define_all_syms, "CONFIG_{name} = Symbol:new(nil, \"{name}\")").into_lua_err()?;
if !name.chars().next().unwrap().is_ascii_digit() {
writeln!(define_all_syms, "{name} = CONFIG_{name}").into_lua_err()?;
}
}
lua_ctx
.load(&define_all_syms)
.set_name("<internal>::define_all_syms")?
.exec()?;

lua_ctx.load(&self.code).set_name(&self.filename)?.exec()?;
Ok(())
})
}
self.lua
.load(&define_all_syms)
.set_name("<internal>::define_all_syms")
.exec()?;

self.lua.load(&self.code).set_name(&self.filename).exec()?;
core::result::Result::Ok(())
})?;

Ok(())
Expand Down

0 comments on commit df50719

Please sign in to comment.