add data path to daemon to store models in postgres data dir when run in bgworker
var77 committed Oct 7, 2024
1 parent 94e2d56 commit 664f1a7
Showing 9 changed files with 54 additions and 16 deletions.
4 changes: 4 additions & 0 deletions lantern_cli/src/daemon/cli.rs
@@ -67,6 +67,10 @@ pub struct DaemonArgs {
#[arg(long)]
pub label: Option<String>,

/// Data path
#[arg(long)]
pub data_path: Option<String>,

/// Log level
#[arg(long, value_enum, default_value_t = LogLevel::Info)] // arg_enum here
pub log_level: LogLevel,
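For context, a minimal standalone sketch of how clap's derive API exposes the new Option<String> field as an optional --data-path flag (illustrative only; the struct name and main function below are not from this repository):

use clap::Parser;

/// Illustrative args struct showing just the new optional flag.
#[derive(Parser, Debug)]
struct ExampleArgs {
    /// Data path
    #[arg(long)]
    data_path: Option<String>,
}

fn main() {
    // `--data-path /var/lib/postgresql/data/lantern` parses to Some("...");
    // omitting the flag leaves the field as None.
    let args = ExampleArgs::parse();
    println!("data_path = {:?}", args.data_path);
}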
32 changes: 19 additions & 13 deletions lantern_cli/src/daemon/embedding_jobs.rs
@@ -594,7 +594,7 @@ async fn job_insert_processor(
schema: String,
table: String,
daemon_label: String,
data_path: &'static str,
data_path: String,
jobs_map: Arc<JobEventHandlersMap>,
job_batching_hashmap: Arc<JobBatchingHashMap>,
client_jobs_map: Arc<ClientJobsMap>,
@@ -627,6 +627,7 @@ async fn job_insert_processor(
let logger_r1 = logger.clone();
let lock_table_name = Arc::new(get_full_table_name(&schema, EMB_LOCK_TABLE_NAME));
let job_batching_hashmap_r1 = job_batching_hashmap.clone();
let data_path_clone = data_path.clone();

let (insert_client, connection) = tokio_postgres::connect(&db_uri_r1, NoTls).await?;
let insert_client = Arc::new(insert_client);
@@ -689,7 +690,7 @@ async fn job_insert_processor(
.get::<&str, Option<SystemTime>>("init_finished_at")
.is_none();

let job = EmbeddingJob::new(row, data_path, &db_uri_r1);
let job = EmbeddingJob::new(row, &data_path, &db_uri_r1);

if let Err(e) = &job {
logger_r1.error(&format!("Error while creating job {id}: {e}",));
@@ -770,7 +771,7 @@ async fn job_insert_processor(
continue;
}
let row = job_result.unwrap();
let job = EmbeddingJob::new(row, data_path, &db_uri);
let job = EmbeddingJob::new(row, &data_path_clone, &db_uri);

if let Err(e) = &job {
logger.error(&format!("Error while creating job {job_id}: {e}"));
@@ -934,33 +935,38 @@ async fn job_update_processor(
}
}

async fn create_data_path(logger: Arc<Logger>) -> Result<&'static str, anyhow::Error> {
let tmp_path = "/tmp/lantern-daemon";
let data_path = if cfg!(target_os = "macos") {
"/usr/local/var/lantern-daemon"
async fn create_data_path(
logger: Arc<Logger>,
data_path: Option<String>,
) -> Result<String, anyhow::Error> {
let tmp_path = "/tmp/lantern-daemon".to_owned();
let local_path = if cfg!(target_os = "macos") {
"/usr/local/var/lantern-daemon".to_owned()
} else {
"/var/lib/lantern-daemon"
"/var/lib/lantern-daemon".to_owned()
};

let data_path_obj = Path::new(data_path);
let data_path = data_path.unwrap_or(local_path);

let data_path_obj = Path::new(&data_path);
if data_path_obj.exists() {
return Ok(data_path);
}

if fs::create_dir(data_path).await.is_ok() {
if fs::create_dir(&data_path).await.is_ok() {
return Ok(data_path);
}

logger.warn(&format!(
"No write permission in directory {data_path}. Writing data to temp directory"
));
let tmp_path_obj = Path::new(tmp_path);
let tmp_path_obj = Path::new(&tmp_path);

if tmp_path_obj.exists() {
return Ok(tmp_path);
}

if let Err(e) = fs::create_dir(tmp_path).await {
if let Err(e) = fs::create_dir(&tmp_path).await {
match e.kind() {
std::io::ErrorKind::AlreadyExists => {}
_ => anyhow::bail!(e),
@@ -1019,7 +1025,7 @@ pub async fn start(
let connection_task = tokio::spawn(async move { connection.await });

let notification_channel = "lantern_cloud_embedding_jobs_v2";
let data_path = create_data_path(logger.clone()).await?;
let data_path = create_data_path(logger.clone(), args.data_path).await?;

let (insert_notification_queue_tx, insert_notification_queue_rx): (
UnboundedSender<JobInsertNotification>,
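The reworked create_data_path above now prefers a caller-supplied path, then a platform default, then a temp directory. A simplified synchronous sketch of that fallback order (an illustration using std::fs; the daemon itself is async and uses tokio::fs, and the helper name below is hypothetical):

use std::fs;
use std::io::ErrorKind;
use std::path::Path;

fn resolve_data_path(data_path: Option<String>) -> std::io::Result<String> {
    let local_path = if cfg!(target_os = "macos") {
        "/usr/local/var/lantern-daemon".to_owned()
    } else {
        "/var/lib/lantern-daemon".to_owned()
    };
    let data_path = data_path.unwrap_or(local_path);

    // Use the preferred path if it already exists or can be created.
    if Path::new(&data_path).exists() || fs::create_dir(&data_path).is_ok() {
        return Ok(data_path);
    }

    // Otherwise fall back to a temp directory, tolerating a concurrent create.
    let tmp_path = "/tmp/lantern-daemon".to_owned();
    if let Err(e) = fs::create_dir(&tmp_path) {
        if e.kind() != ErrorKind::AlreadyExists {
            return Err(e);
        }
    }
    Ok(tmp_path)
}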
3 changes: 3 additions & 0 deletions lantern_cli/src/daemon/mod.rs
@@ -122,6 +122,7 @@ async fn spawn_job(
uri: target_db.uri.clone(),
schema: args.schema.clone(),
log_level: args.log_level.value(),
data_path: args.data_path.clone(),
table_name: "embedding_generation_jobs".to_owned(),
},
processor_tx.clone(),
Expand All @@ -137,6 +138,7 @@ async fn spawn_job(
uri: target_db.uri.clone(),
schema: args.schema.clone(),
log_level: args.log_level.value(),
data_path: None,
table_name: "external_index_jobs".to_owned(),
},
processor_tx.clone(),
@@ -152,6 +154,7 @@ async fn spawn_job(
uri: target_db.uri.clone(),
schema: args.schema.clone(),
log_level: args.log_level.value(),
data_path: None,
table_name: "autotune_jobs".to_owned(),
},
processor_tx.clone(),
1 change: 1 addition & 0 deletions lantern_cli/src/daemon/types.rs
@@ -21,6 +21,7 @@ pub struct JobRunArgs {
pub log_level: crate::logger::LogLevel,
pub table_name: String,
pub label: Option<String>,
pub data_path: Option<String>,
}

#[derive(Clone, Debug)]
2 changes: 1 addition & 1 deletion lantern_cli/src/embeddings/core/ort_runtime.rs
@@ -155,7 +155,7 @@ pub struct EncoderOptions {
pub input_image_size: Option<usize>,
}

const DATA_PATH: &'static str = ".ldb_extras_data/";
pub const DATA_PATH: &'static str = ".ldb_extras_data/";
const MAX_IMAGE_SIZE: u64 = 1024 * 1024 * 20; // 20 MB

struct ModelInfo {
1 change: 1 addition & 0 deletions lantern_cli/tests/daemon_autotune_test_with_db.rs
@@ -47,6 +47,7 @@ async fn test_daemon_autotune_with_create_index() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
17 changes: 17 additions & 0 deletions lantern_cli/tests/daemon_embeddings_test_with_db.rs
@@ -46,6 +46,7 @@ async fn test_daemon_embedding_init_job() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -96,6 +97,7 @@ async fn test_daemon_embedding_job_client_insert_listener() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -168,6 +170,7 @@ async fn test_daemon_embedding_job_client_update_listener() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -239,6 +242,7 @@ async fn test_daemon_embedding_job_resume() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -299,6 +303,7 @@ async fn test_daemon_embedding_finished_job_listener() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -375,6 +380,7 @@ async fn test_daemon_embedding_multiple_jobs_listener() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -462,6 +468,7 @@ async fn test_daemon_embedding_multiple_new_jobs_streaming() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -557,6 +564,7 @@ async fn test_daemon_embedding_multiple_new_jobs_with_failure() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -650,6 +658,7 @@ async fn test_daemon_embedding_jobs_streaming_with_failure() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -740,6 +749,7 @@ async fn test_daemon_job_labels_mismatch() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -784,6 +794,7 @@ async fn test_daemon_job_labels_match_insert() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -848,6 +859,7 @@ async fn test_daemon_job_label_update() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -929,6 +941,7 @@ async fn test_daemon_job_label_update_cancel() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -1009,6 +1022,7 @@ async fn test_daemon_job_labels_match() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -1063,6 +1077,7 @@ async fn test_daemon_embedding_init_job_streaming_large() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -1139,6 +1154,7 @@ async fn test_daemon_embedding_job_text_pk() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -1196,6 +1212,7 @@ async fn test_daemon_embedding_job_uuid_pk() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
4 changes: 3 additions & 1 deletion lantern_cli/tests/daemon_index_test_with_db.rs
@@ -45,6 +45,7 @@ async fn test_daemon_external_index_create_on_small_table() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -100,6 +101,7 @@ async fn test_daemon_external_index_create() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
@@ -157,6 +159,7 @@ async fn test_daemon_external_index_wrong_ops() {
schema: "_lantern_extras_internal".to_owned(),
target_db: Some(vec![new_connection_uri]),
log_level: LogLevel::Debug,
data_path: None,
},
None,
cancel_token_clone,
Expand All @@ -177,4 +180,3 @@ async fn test_daemon_external_index_wrong_ops() {

cancel_token.cancel();
}

6 changes: 5 additions & 1 deletion lantern_extras/src/daemon.rs
@@ -1,6 +1,9 @@
use lantern_cli::{
daemon::{cli::DaemonArgs, start},
embeddings::core::{cohere_runtime::CohereRuntimeParams, openai_runtime::OpenAiRuntimeParams},
embeddings::core::{
cohere_runtime::CohereRuntimeParams, openai_runtime::OpenAiRuntimeParams,
ort_runtime::DATA_PATH,
},
logger::{LogLevel, Logger},
types::AnyhowVoidResult,
utils::{get_full_table_name, quote_ident},
@@ -88,6 +91,7 @@ pub fn start_daemon(
master_db_schema: String::new(),
schema: String::from("_lantern_extras_internal"),
target_db: Some(target_dbs.clone()),
data_path: Some(DATA_PATH.to_owned()),
},
Some(logger.clone()),
cancellation_token.clone(),
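Passing DATA_PATH (".ldb_extras_data/", a relative path) from lantern_extras means the daemon resolves it against the process working directory. Postgres server processes, including background workers, run with the cluster data directory as their working directory, which is how models end up under the Postgres data dir as the commit message describes. A small illustration of that resolution (an assumption-level sketch, not code from this repository):

use std::env;

fn main() -> std::io::Result<()> {
    // Same literal as ort_runtime::DATA_PATH. Being relative, it resolves
    // against the current working directory, which for a Postgres background
    // worker is the cluster data directory ($PGDATA).
    let data_path = ".ldb_extras_data/";
    let resolved = env::current_dir()?.join(data_path);
    println!("models would be stored under {}", resolved.display());
    Ok(())
}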
