fix: use datafusion apis to convert cid bytes to string
Samika Kashyap authored and committed Aug 27, 2024
1 parent ccce37f commit 4f4632a
Showing 4 changed files with 145 additions and 55 deletions.
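The commit replaces the manual hex-substitution helper in the tests with a DataFusion scalar UDF (`CidString`) that decodes binary CID columns into their canonical string form. The core conversion it performs is roughly the following sketch, simplified from the `invoke` implementation in the tests.rs diff below; the helper name here is illustrative and not part of the commit:

```rust
use arrow::array::{BinaryArray, StringArray, StringBuilder};
use cid::Cid;

/// Sketch of the conversion the new `cid_string` UDF performs: decode each
/// binary CID value into its canonical string form, preserving nulls.
/// (Illustrative free function; the committed code wraps this logic in a
/// DataFusion ScalarUDF.)
fn cid_bytes_to_strings(cids: &BinaryArray) -> anyhow::Result<StringArray> {
    let mut out = StringBuilder::new();
    for value in cids.iter() {
        match value {
            Some(bytes) => out.append_value(Cid::read_bytes(bytes)?.to_string()),
            None => out.append_null(),
        }
    }
    Ok(out.finish())
}
```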
1 change: 1 addition & 0 deletions flight/Cargo.toml
@@ -21,6 +21,7 @@ hex.workspace = true
[dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt"] }
test-log.workspace = true
datafusion.workspace = true

[package.metadata.cargo-machete]
ignored = [
13 changes: 11 additions & 2 deletions flight/src/conversion.rs
@@ -1,8 +1,7 @@
use crate::types::*;
use anyhow::Result;
use arrow::array::{
ArrayRef, BinaryBuilder, ListBuilder, PrimitiveBuilder, StringBuilder, StructArray,
UInt8Builder,
ArrayRef, BinaryBuilder, ListBuilder, PrimitiveBuilder, StringBuilder, StructArray, UInt64Builder, UInt8Builder
};
use arrow::datatypes::{DataType, Field};
use arrow::record_batch::RecordBatch;
@@ -17,6 +16,7 @@ pub struct ConclusionEventBuilder {
controller: StringBuilder,
data: BinaryBuilder,
previous: ListBuilder<BinaryBuilder>,
index: UInt64Builder
}

impl Default for ConclusionEventBuilder {
@@ -30,6 +30,7 @@ impl Default for ConclusionEventBuilder {
data: BinaryBuilder::new(),
previous: ListBuilder::new(BinaryBuilder::new())
.with_field(Field::new_list_field(DataType::Binary, false)),
index: UInt64Builder::new()
}
}
}
@@ -49,6 +50,7 @@ impl ConclusionEventBuilder {
self.previous.values().append_value(cid.to_bytes());
}
self.previous.append(!data_event.previous.is_empty());
self.index.append_value(data_event.index);
}
ConclusionEvent::Time(time_event) => {
self.stream_cid
@@ -61,6 +63,7 @@ impl ConclusionEventBuilder {
self.previous.values().append_value(cid.to_bytes());
}
self.previous.append(!time_event.previous.is_empty());
self.index.append_value(time_event.index);
}
}
}
@@ -92,6 +95,9 @@ impl ConclusionEventBuilder {
true,
));

let index = Arc::new(self.index.finish()) as ArrayRef;
let index_field = Arc::new(Field::new("index", DataType::UInt64, false));

StructArray::from(vec![
(event_type_field, event_type),
(stream_cid_field, stream_cid),
@@ -100,6 +106,7 @@ impl ConclusionEventBuilder {
(event_cid_field, event_cid),
(data_field, data),
(previous_field, previous),
(index_field, index)
])
}
}
@@ -139,6 +146,7 @@ impl<'a> Extend<&'a ConclusionEvent> for ConclusionEventBuilder {
/// fn main() -> Result<()> {
/// let events = vec![
/// ConclusionEvent::Data(ConclusionData {
/// index: 0,
/// event_cid: Cid::default(),
/// init: ConclusionInit {
/// stream_cid: Cid::default(),
@@ -150,6 +158,7 @@ impl<'a> Extend<&'a ConclusionEvent> for ConclusionEventBuilder {
/// data: vec![1, 2, 3],
/// }),
/// ConclusionEvent::Data(ConclusionData {
/// index: 1,
/// event_cid: Cid::default(),
/// init: ConclusionInit {
/// stream_cid: Cid::default(),
184 changes: 131 additions & 53 deletions flight/src/tests.rs
@@ -1,48 +1,71 @@
use super::*;
use crate::types::{ConclusionData, ConclusionEvent, ConclusionInit};
use arrow::{
array::{Array, BinaryArray, RecordBatch},
array::{Array, BinaryArray, RecordBatch, StringBuilder},
datatypes::DataType,
util::pretty::pretty_format_batches,
};
use ceramic_core::StreamIdType;
use cid::Cid;
use datafusion::{
common::{cast::as_binary_array, exec_datafusion_err},
execution::context::SessionContext,
functions_aggregate::expr_fn::array_agg,
logical_expr::{
col, expr::ScalarFunction, ColumnarValue, Expr, ScalarUDF, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
},
};
use expect_test::expect;
use hex;
use std::str::FromStr;
use std::{any::Any, str::FromStr, sync::Arc};

fn convert_cids_to_string(record_batch: &RecordBatch) -> String {
let mut formatted = pretty_format_batches(&[record_batch.clone()])
.unwrap()
.to_string();
#[derive(Debug)]
pub struct CidString {
signature: Signature,
}

// Assuming that `stream_cid` and `event_cid` are the columns that need conversion
if let Some(array) = record_batch
.column(record_batch.schema().index_of("stream_cid").unwrap())
.as_any()
.downcast_ref::<BinaryArray>()
{
for i in 0..array.len() {
let cid_bytes = array.value(i);
let cid = Cid::try_from(cid_bytes).expect("Invalid CID");
let cid_str = cid.to_string();
formatted = formatted.replace(&hex::encode(cid_bytes), &cid_str);
impl CidString {
pub fn new() -> Self {
Self {
signature: Signature::new(
TypeSignature::Exact(vec![DataType::Binary]),
Volatility::Immutable,
),
}
}
}

if let Some(array) = record_batch
.column(record_batch.schema().index_of("event_cid").unwrap())
.as_any()
.downcast_ref::<BinaryArray>()
{
for i in 0..array.len() {
let cid_bytes = array.value(i);
let cid = Cid::try_from(cid_bytes).expect("Invalid CID");
let cid_str = cid.to_string();
formatted = formatted.replace(&hex::encode(cid_bytes), &cid_str);
impl ScalarUDFImpl for CidString {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"cid_string"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> datafusion::common::Result<DataType> {
Ok(DataType::Utf8)
}
fn invoke(&self, args: &[ColumnarValue]) -> datafusion::common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
let cids = as_binary_array(&args[0])?;
let mut strs = StringBuilder::new();
for cid in cids {
if let Some(cid) = cid {
strs.append_value(
Cid::read_bytes(cid)
.map_err(|err| exec_datafusion_err!("Error {err}"))?
.to_string(),
);
} else {
strs.append_null()
}
}
Ok(ColumnarValue::Array(Arc::new(strs.finish())))
}

formatted
}

/// Tests the conversion of ConclusionEvents to Arrow RecordBatch.
@@ -55,12 +78,12 @@ fn convert_cids_to_string(record_batch: &RecordBatch) -> String {
/// 1. The number of rows in the RecordBatch
/// 2. The schema of the RecordBatch
/// 3. The content of each column in the RecordBatch
#[test]
fn test_conclusion_events_to_record_batch() {
#[tokio::test]
async fn test_conclusion_events_to_record_batch() {
// Create mock ConclusionEvents
let events = vec![
ConclusionEvent::Data(ConclusionData {
event_cid: Cid::from_str("baeabeials2i6o2ppkj55kfbh7r2fzc73r2esohqfivekpag553lyc7f6bi")
event_cid: Cid::from_str("baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu")
.unwrap(),
init: ConclusionInit {
stream_cid: Cid::from_str(
@@ -71,11 +94,9 @@ fn test_conclusion_events_to_record_batch() {
controller: "did:key:test1".to_string(),
dimensions: vec![],
},
previous: vec![Cid::from_str(
"baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu",
)
.unwrap()],
previous: vec![],
data: vec![1, 2, 3],
index: 0,
}),
ConclusionEvent::Data(ConclusionData {
event_cid: Cid::from_str("baeabeid2w5pgdsdh25nah7batmhxanbj3x2w2is3atser7qxboyojv236q")
@@ -86,14 +107,15 @@ fn test_conclusion_events_to_record_batch() {
)
.unwrap(),
stream_type: StreamIdType::Model as u8,
controller: "did:key:test2".to_string(),
controller: "did:key:test1".to_string(),
dimensions: vec![],
},
previous: vec![Cid::from_str(
"baeabeials2i6o2ppkj55kfbh7r2fzc73r2esohqfivekpag553lyc7f6bi",
"baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu",
)
.unwrap()],
data: vec![4, 5, 6],
index: 1,
}),
ConclusionEvent::Time(ConclusionTime {
event_cid: Cid::from_str("baeabeidtub3bnbojbickf6d4pqscaw6xpt5ksgido7kcsg2jyftaj237di")
@@ -108,9 +130,10 @@ fn test_conclusion_events_to_record_batch() {
)
.unwrap(),
stream_type: StreamIdType::Model as u8,
controller: "did:key:test3".to_string(),
controller: "did:key:test1".to_string(),
dimensions: vec![],
},
index: 2
}),
ConclusionEvent::Data(ConclusionData {
event_cid: Cid::from_str("baeabeiewqcj4bwhcssizv5kcyvsvm57bxghjpqshnbzkc6rijmwb4im4yq")
@@ -121,14 +144,17 @@ fn test_conclusion_events_to_record_batch() {
)
.unwrap(),
stream_type: StreamIdType::Model as u8,
controller: "did:key:test4".to_string(),
controller: "did:key:test1".to_string(),
dimensions: vec![],
},
previous: vec![Cid::from_str(
"baeabeidtub3bnbojbickf6d4pqscaw6xpt5ksgido7kcsg2jyftaj237di",
)
.unwrap()],
previous: vec![
Cid::from_str("baeabeidtub3bnbojbickf6d4pqscaw6xpt5ksgido7kcsg2jyftaj237di")
.unwrap(),
Cid::from_str("baeabeid2w5pgdsdh25nah7batmhxanbj3x2w2is3atser7qxboyojv236q")
.unwrap(),
],
data: vec![7, 8, 9],
index: 3
}),
];
// Convert events to RecordBatch
@@ -137,16 +163,68 @@ fn test_conclusion_events_to_record_batch() {
// Convert RecordBatch to string
// let formatted = pretty_format_batches(&[record_batch.clone()])
// .unwrap();
let formatted = convert_cids_to_string(&record_batch);
// let formatted = convert_cids_to_string(&record_batch);

let ctx = SessionContext::new();
ctx.register_batch("conclusion_feed", record_batch).unwrap();

let cid_string = Arc::new(ScalarUDF::from(CidString::new()));

let doc_state = ctx
.table("conclusion_feed")
.await
.unwrap()
.unnest_columns(&["previous"])
.unwrap()
.select(vec![
col("index"),
col("event_type"),
Expr::ScalarFunction(ScalarFunction::new_udf(
cid_string.clone(),
vec![col("stream_cid")],
))
.alias("stream_cid"),
col("stream_type"),
col("controller"),
Expr::ScalarFunction(ScalarFunction::new_udf(
cid_string.clone(),
vec![col("event_cid")],
))
.alias("event_cid"),
col("data"),
Expr::ScalarFunction(ScalarFunction::new_udf(cid_string, vec![col("previous")]))
.alias("previous"),
])
.unwrap()
.aggregate(
vec![
col("index"),
col("event_type"),
col("stream_cid"),
col("stream_type"),
col("controller"),
col("event_cid"),
col("data"),
],
vec![array_agg(col("previous")).alias("previous")],
)
.unwrap()
.sort(vec![col("index").sort(true, true)])
.unwrap()
.collect()
.await
.unwrap();

let formatted = pretty_format_batches(&doc_state).unwrap().to_string();

// Use expect_test to validate the output
expect![[r#"
+------------+--------------------------------------------------------------------------+-------------+---------------+--------------------------------------------------------------------------+--------+----------------------------------------------------------------------------+
| event_type | stream_cid | stream_type | controller | event_cid | data | previous |
+------------+--------------------------------------------------------------------------+-------------+---------------+--------------------------------------------------------------------------+--------+----------------------------------------------------------------------------+
| 0 | baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu | 2 | did:key:test1 | baeabeials2i6o2ppkj55kfbh7r2fzc73r2esohqfivekpag553lyc7f6bi | 010203 | [baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu] |
| 0 | baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu | 2 | did:key:test2 | baeabeid2w5pgdsdh25nah7batmhxanbj3x2w2is3atser7qxboyojv236q | 040506 | [baeabeials2i6o2ppkj55kfbh7r2fzc73r2esohqfivekpag553lyc7f6bi] |
| 1 | baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu | 2 | did:key:test3 | baeabeidtub3bnbojbickf6d4pqscaw6xpt5ksgido7kcsg2jyftaj237di | | [baeabeid2w5pgdsdh25nah7batmhxanbj3x2w2is3atser7qxboyojv236q] |
| 0 | baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu | 2 | did:key:test4 | baeabeiewqcj4bwhcssizv5kcyvsvm57bxghjpqshnbzkc6rijmwb4im4yq | 070809 | [baeabeidtub3bnbojbickf6d4pqscaw6xpt5ksgido7kcsg2jyftaj237di] |
+------------+--------------------------------------------------------------------------+-------------+---------------+--------------------------------------------------------------------------+--------+----------------------------------------------------------------------------+"#]].assert_eq(&formatted);
+-------+------------+-------------------------------------------------------------+-------------+---------------+-------------------------------------------------------------+--------+----------------------------------------------------------------------------------------------------------------------------+
| index | event_type | stream_cid | stream_type | controller | event_cid | data | previous |
+-------+------------+-------------------------------------------------------------+-------------+---------------+-------------------------------------------------------------+--------+----------------------------------------------------------------------------------------------------------------------------+
| 0 | 0 | baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu | 2 | did:key:test1 | baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu | 010203 | [] |
| 1 | 0 | baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu | 2 | did:key:test1 | baeabeid2w5pgdsdh25nah7batmhxanbj3x2w2is3atser7qxboyojv236q | 040506 | [baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu] |
| 2 | 1 | baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu | 2 | did:key:test1 | baeabeidtub3bnbojbickf6d4pqscaw6xpt5ksgido7kcsg2jyftaj237di | | [baeabeid2w5pgdsdh25nah7batmhxanbj3x2w2is3atser7qxboyojv236q] |
| 3 | 0 | baeabeif2fdfqe2hu6ugmvgozkk3bbp5cqi4udp5rerjmz4pdgbzf3fvobu | 2 | did:key:test1 | baeabeiewqcj4bwhcssizv5kcyvsvm57bxghjpqshnbzkc6rijmwb4im4yq | 070809 | [baeabeidtub3bnbojbickf6d4pqscaw6xpt5ksgido7kcsg2jyftaj237di, baeabeid2w5pgdsdh25nah7batmhxanbj3x2w2is3atser7qxboyojv236q] |
+-------+------------+-------------------------------------------------------------+-------------+---------------+-------------------------------------------------------------+--------+----------------------------------------------------------------------------------------------------------------------------+"#]].assert_eq(&formatted);
}
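The test above drives the new UDF through the DataFrame API (`ScalarFunction::new_udf`). As a usage note, the same UDF could also be registered on the `SessionContext` and invoked from SQL. The following is a hypothetical sketch, not part of this commit, assuming `conclusion_feed` has been registered as in the test:

```rust
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::ScalarUDF;

// Hypothetical usage sketch: register the `CidString` UDF defined in tests.rs
// on the context and call it from SQL instead of the DataFrame API.
async fn query_with_sql(ctx: &SessionContext) -> datafusion::common::Result<()> {
    ctx.register_udf(ScalarUDF::from(CidString::new()));
    let batches = ctx
        .sql(
            r#"SELECT "index",
                      cid_string(stream_cid) AS stream_cid,
                      cid_string(event_cid) AS event_cid
               FROM conclusion_feed"#,
        )
        .await?
        .collect()
        .await?;
    arrow::util::pretty::print_batches(&batches)?;
    Ok(())
}
```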
2 changes: 2 additions & 0 deletions flight/src/types.rs
@@ -38,6 +38,7 @@ pub struct ConclusionData {
// TODO : rethink this, add a checkpoint to make it work in datafusion query
pub previous: Vec<Cid>,
pub data: Vec<u8>,
pub index: u64
}

pub type CeramicTime = chrono::DateTime<chrono::Utc>;
@@ -48,4 +49,5 @@ pub struct ConclusionTime {
pub event_cid: Cid,
pub init: ConclusionInit,
pub previous: Vec<Cid>,
pub index: u64
}
