
Commit

rename to factory
Colin Ho authored and Colin Ho committed Oct 3, 2024
1 parent 823e889 commit 3854cd2
Showing 6 changed files with 59 additions and 43 deletions.
17 changes: 13 additions & 4 deletions src/daft-local-execution/src/pipeline.rs
@@ -34,7 +34,7 @@ use crate::{
},
sources::in_memory::InMemorySource,
writes::{
-partitioned_write::PartitionedWriteNode, physical_write::PhysicalWriteOperator,
+partitioned_write::PartitionedWriteNode, physical_write::PhysicalWriterFactory,
unpartitioned_write::UnpartitionedWriteNode,
},
ExecutionRuntimeHandle, PipelineCreationSnafu,
@@ -376,20 +376,29 @@ pub fn physical_plan_to_pipeline(
target_chunk_size as f64
} as usize,
);
-let write_op = PhysicalWriteOperator::new(file_info.clone());
+let write_factory = PhysicalWriterFactory::new(file_info.clone());
+let name = match (&file_info.partition_cols.is_some(), &file_info.file_format) {
+(true, FileFormat::Parquet) => "PartitionedParquetWrite",
+(true, FileFormat::Csv) => "PartitionedCSVWrite",
+(false, FileFormat::Parquet) => "UnpartitionedParquetWrite",
+(false, FileFormat::Csv) => "UnpartitionedCSVWrite",
+_ => unreachable!("Physical write should only support Parquet and CSV"),
+};
match &file_info.partition_cols {
Some(part_cols) => PartitionedWriteNode::new(
+name,
child_node,
-Arc::new(write_op),
+Arc::new(write_factory),
part_cols.clone(),
target_file_rows,
target_chunk_rows,
file_schema.clone(),
)
.boxed(),
None => UnpartitionedWriteNode::new(
+name,
child_node,
-Arc::new(write_op),
+Arc::new(write_factory),
target_file_rows,
target_chunk_rows,
file_schema.clone(),
3 changes: 1 addition & 2 deletions src/daft-local-execution/src/writes/mod.rs
@@ -6,8 +6,7 @@ pub mod partitioned_write;
pub mod physical_write;
pub mod unpartitioned_write;

-pub trait WriteOperator: Send + Sync {
-fn name(&self) -> &'static str;
+pub trait WriterFactory: Send + Sync {
fn create_writer(
&self,
file_idx: usize,
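
(Editorial note: the fragment above is truncated by the diff view. Piecing it together with the call sites later in this commit — `create_writer(0, Some(&partition_values))?` feeding a `writer: Box<dyn FileWriter>` field — the renamed trait plausibly reads as sketched below. This is a reconstruction, not the verbatim source: the `common_error::DaftResult` import, the `partition_values` parameter name, and the return type are inferred.)

```rust
use common_error::DaftResult;
use daft_micropartition::FileWriter;
use daft_table::Table;

// Renamed from `WriteOperator`. The factory is now only responsible for
// creating writers; the old `name()` method moves onto the pipeline nodes
// themselves (see the `name: &'static str` fields added below).
pub trait WriterFactory: Send + Sync {
    fn create_writer(
        &self,
        file_idx: usize,
        partition_values: Option<&Table>,
    ) -> DaftResult<Box<dyn FileWriter>>;
}
```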
33 changes: 18 additions & 15 deletions src/daft-local-execution/src/writes/partitioned_write.rs
@@ -9,7 +9,7 @@ use daft_micropartition::{FileWriter, MicroPartition};
use daft_table::Table;
use snafu::ResultExt;

-use super::WriteOperator;
+use super::WriterFactory;
use crate::{
buffer::RowBasedBuffer,
channel::{create_channel, PipelineChannel, Receiver, Sender},
@@ -21,7 +21,7 @@ use crate::{

struct PerPartitionWriter {
writer: Box<dyn FileWriter>,
-write_operator: Arc<dyn WriteOperator>,
+writer_factory: Arc<dyn WriterFactory>,
partition_values: Table,
buffer: RowBasedBuffer,
target_file_rows: usize,
@@ -31,14 +31,14 @@ struct PerPartitionWriter {

impl PerPartitionWriter {
fn new(
-write_operator: Arc<dyn WriteOperator>,
+writer_factory: Arc<dyn WriterFactory>,
partition_values: Table,
target_file_rows: usize,
target_chunk_rows: usize,
) -> DaftResult<Self> {
Ok(Self {
-writer: write_operator.create_writer(0, Some(&partition_values))?,
-write_operator,
+writer: writer_factory.create_writer(0, Some(&partition_values))?,
+writer_factory,
partition_values,
buffer: RowBasedBuffer::new(target_chunk_rows),
target_file_rows,
@@ -70,7 +70,7 @@ impl PerPartitionWriter {
}
self.written_rows_so_far = 0;
self.writer = self
-.write_operator
+.writer_factory
.create_writer(self.results.len(), Some(&self.partition_values))?
}
Ok(())
@@ -92,9 +92,10 @@ impl PerPartitionWriter {
}

pub(crate) struct PartitionedWriteNode {
+name: &'static str,
child: Box<dyn PipelineNode>,
runtime_stats: Arc<RuntimeStatsContext>,
-write_operator: Arc<dyn WriteOperator>,
+writer_factory: Arc<dyn WriterFactory>,
partition_cols: Vec<ExprRef>,
target_file_rows: usize,
target_chunk_rows: usize,
@@ -103,18 +104,20 @@ pub(crate) struct PartitionedWriteNode {

impl PartitionedWriteNode {
pub(crate) fn new(
+name: &'static str,
child: Box<dyn PipelineNode>,
-write_operator: Arc<dyn WriteOperator>,
+writer_factory: Arc<dyn WriterFactory>,
partition_cols: Vec<ExprRef>,
target_file_rows: usize,
target_chunk_rows: usize,
file_schema: SchemaRef,
) -> Self {
Self {
+name,
child,
runtime_stats: RuntimeStatsContext::new(),
partition_cols,
-write_operator,
+writer_factory,
target_file_rows,
target_chunk_rows,
file_schema,
@@ -137,7 +140,7 @@ impl PartitionedWriteNode {

async fn run_writer(
mut input_receiver: Receiver<(Table, Table)>,
-write_operator: Arc<dyn WriteOperator>,
+writer_factory: Arc<dyn WriterFactory>,
target_chunk_rows: usize,
target_file_rows: usize,
) -> DaftResult<Vec<Table>> {
@@ -149,7 +152,7 @@ impl PartitionedWriteNode {
per_partition_writers.insert(
partition_values_str.clone(),
PerPartitionWriter::new(
-write_operator.clone(),
+writer_factory.clone(),
partition_values,
target_file_rows,
target_chunk_rows,
@@ -182,7 +185,7 @@ impl PartitionedWriteNode {
fn spawn_writers(
num_writers: usize,
task_set: &mut tokio::task::JoinSet<DaftResult<Vec<Table>>>,
-write_operator: Arc<dyn WriteOperator>,
+writer_factory: Arc<dyn WriterFactory>,
target_chunk_rows: usize,
target_file_rows: usize,
) -> Vec<Sender<(Table, Table)>> {
@@ -191,7 +194,7 @@ impl PartitionedWriteNode {
let (writer_sender, writer_receiver) = create_channel(1);
task_set.spawn(Self::run_writer(
writer_receiver,
-write_operator.clone(),
+writer_factory.clone(),
target_chunk_rows,
target_file_rows,
));
@@ -252,7 +255,7 @@ impl PipelineNode for PartitionedWriteNode {
}

fn name(&self) -> &'static str {
-self.write_operator.name()
+self.name
}

fn start(
@@ -276,7 +279,7 @@ impl PipelineNode for PartitionedWriteNode {
let writer_senders = Self::spawn_writers(
*NUM_CPUS,
&mut task_set,
-self.write_operator.clone(),
+self.writer_factory.clone(),
self.target_chunk_rows,
self.target_file_rows,
);
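
(Editorial note: both write nodes use the same fan-out shape — a single `Arc<dyn WriterFactory>` is cloned into each writer task spawned on a `JoinSet`, as in `spawn_writers` above. Below is a minimal standalone sketch of that pattern, assuming a Tokio runtime; `Factory`, `Dummy`, and the string payload are stand-ins, not names from the repository.)

```rust
use std::sync::Arc;

// Stand-in for `WriterFactory`: a Send + Sync trait object shared across tasks.
trait Factory: Send + Sync {
    fn make(&self, idx: usize) -> String;
}

struct Dummy;

impl Factory for Dummy {
    fn make(&self, idx: usize) -> String {
        format!("writer-{idx}")
    }
}

#[tokio::main]
async fn main() {
    let factory: Arc<dyn Factory> = Arc::new(Dummy);
    let mut task_set = tokio::task::JoinSet::new();
    for i in 0..4 {
        // Each writer task gets its own handle to the shared factory,
        // mirroring `writer_factory.clone()` in `spawn_writers`.
        let factory = factory.clone();
        task_set.spawn(async move { factory.make(i) });
    }
    while let Some(result) = task_set.join_next().await {
        println!("{}", result.unwrap());
    }
}
```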
11 changes: 4 additions & 7 deletions src/daft-local-execution/src/writes/physical_write.rs
@@ -3,22 +3,19 @@ use daft_micropartition::{create_file_writer, FileWriter};
use daft_plan::OutputFileInfo;
use daft_table::Table;

-use super::WriteOperator;
+use super::WriterFactory;

-pub(crate) struct PhysicalWriteOperator {
+pub(crate) struct PhysicalWriterFactory {
output_file_info: OutputFileInfo,
}

-impl PhysicalWriteOperator {
+impl PhysicalWriterFactory {
pub(crate) fn new(output_file_info: OutputFileInfo) -> Self {
Self { output_file_info }
}
}

-impl WriteOperator for PhysicalWriteOperator {
-fn name(&self) -> &'static str {
-"PhysicalWriteOperator"
-}
+impl WriterFactory for PhysicalWriterFactory {
fn create_writer(
&self,
file_idx: usize,
33 changes: 22 additions & 11 deletions src/daft-local-execution/src/writes/unpartitioned_write.rs
@@ -7,7 +7,7 @@ use daft_micropartition::{FileWriter, MicroPartition};
use daft_table::Table;
use snafu::ResultExt;

-use super::WriteOperator;
+use super::WriterFactory;
use crate::{
buffer::RowBasedBuffer,
channel::{create_channel, PipelineChannel, Receiver, Sender},
@@ -18,26 +18,29 @@ use crate::{
};

pub(crate) struct UnpartitionedWriteNode {
+name: &'static str,
child: Box<dyn PipelineNode>,
runtime_stats: Arc<RuntimeStatsContext>,
-write_operator: Arc<dyn WriteOperator>,
+writer_factory: Arc<dyn WriterFactory>,
target_in_memory_file_rows: usize,
target_in_memory_chunk_rows: usize,
file_schema: SchemaRef,
}

impl UnpartitionedWriteNode {
pub(crate) fn new(
+name: &'static str,
child: Box<dyn PipelineNode>,
-write_operator: Arc<dyn WriteOperator>,
+writer_factory: Arc<dyn WriterFactory>,
target_in_memory_file_rows: usize,
target_in_memory_chunk_rows: usize,
file_schema: SchemaRef,
) -> Self {
Self {
+name,
child,
runtime_stats: RuntimeStatsContext::new(),
-write_operator,
+writer_factory,
target_in_memory_file_rows,
target_in_memory_chunk_rows,
file_schema,
@@ -48,9 +51,12 @@ impl UnpartitionedWriteNode {
Box::new(self)
}

+// Receives data from the dispatcher and writes it.
+// If the received file idx is different from the current file idx, this means that the current file is full and needs to be closed.
+// Once input is exhausted, the current writer is closed and all written file paths are returned.
async fn run_writer(
mut input_receiver: Receiver<(Arc<MicroPartition>, usize)>,
-write_operator: Arc<dyn WriteOperator>,
+writer_factory: Arc<dyn WriterFactory>,
) -> DaftResult<Vec<Table>> {
let mut written_file_paths = vec![];
let mut current_writer: Option<Box<dyn FileWriter>> = None;
@@ -63,7 +69,7 @@ impl UnpartitionedWriteNode {
}
}
current_file_idx = Some(file_idx);
-current_writer = Some(write_operator.create_writer(file_idx, None)?);
+current_writer = Some(writer_factory.create_writer(file_idx, None)?);
}
if let Some(writer) = current_writer.as_mut() {
writer.write(&data)?;
@@ -80,18 +86,22 @@ impl UnpartitionedWriteNode {
fn spawn_writers(
num_writers: usize,
task_set: &mut TaskSet<DaftResult<Vec<Table>>>,
-write_operator: &Arc<dyn WriteOperator>,
+writer_factory: &Arc<dyn WriterFactory>,
channel_size: usize,
) -> Vec<Sender<(Arc<MicroPartition>, usize)>> {
let mut writer_senders = Vec::with_capacity(num_writers);
for _ in 0..num_writers {
let (writer_sender, writer_receiver) = create_channel(channel_size);
-task_set.spawn(Self::run_writer(writer_receiver, write_operator.clone()));
+task_set.spawn(Self::run_writer(writer_receiver, writer_factory.clone()));
writer_senders.push(writer_sender);
}
writer_senders
}

+// Dispatches data received from the child to the writers
+// As data is received, it is buffered until enough data is available to fill a chunk
+// Once a chunk is filled, it is sent to a writer
+// If the writer has written enough rows for a file, increment the file index and switch to the next writer
async fn dispatch(
mut input_receiver: CountingReceiver,
target_chunk_rows: usize,
@@ -157,7 +167,7 @@ impl PipelineNode for UnpartitionedWriteNode {
}

fn name(&self) -> &'static str {
-self.write_operator.name()
+self.name
}

fn start(
@@ -177,12 +187,13 @@ impl PipelineNode for UnpartitionedWriteNode {
destination_channel.get_next_sender_with_stats(&self.runtime_stats);

// Start writers
-let write_operator = self.write_operator.clone();
+let writer_factory = self.writer_factory.clone();
let mut task_set = create_task_set();
let writer_senders = Self::spawn_writers(
*NUM_CPUS,
&mut task_set,
-&write_operator,
+&writer_factory,
+// The channel size is set to the number of chunks per file such that writes can be parallelized
(self.target_in_memory_file_rows + self.target_in_memory_chunk_rows + 1)
/ self.target_in_memory_chunk_rows,
);
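
(Editorial note: the channel-size expression above is integer arithmetic for roughly one file's worth of chunks, i.e. about `ceil(target_in_memory_file_rows / target_in_memory_chunk_rows)` plus a little slack, so a writer can buffer a whole file's chunks without stalling the dispatcher. A quick check with made-up values — these numbers are illustrative, not taken from the repository.)

```rust
fn main() {
    // Hypothetical targets, chosen only to illustrate the arithmetic.
    let target_in_memory_file_rows: usize = 1_000_000;
    let target_in_memory_chunk_rows: usize = 100_000;

    // Same expression as the `spawn_writers` call site above.
    let channel_size = (target_in_memory_file_rows + target_in_memory_chunk_rows + 1)
        / target_in_memory_chunk_rows;

    assert_eq!(channel_size, 11); // ~ one file's worth of chunks, plus a little slack
    println!("channel_size = {channel_size}");
}
```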
5 changes: 1 addition & 4 deletions tests/benchmarks/conftest.py
@@ -67,10 +67,7 @@ def gen_tpch(request):
num_parts = request.param

csv_files_location = data_generation.gen_csv_files(TPCH_DBGEN_DIR, num_parts, SCALE_FACTOR)
-
-# Disable native executor to generate parquet files, remove once native executor supports writing parquet files
-with daft.context.execution_config_ctx(enable_native_executor=False):
-parquet_files_location = data_generation.gen_parquet(csv_files_location)
+parquet_files_location = data_generation.gen_parquet(csv_files_location)

in_memory_tables = {}
for tbl_name in data_generation.SCHEMA.keys():
