diff --git a/docs/difference-in-differences.qmd b/docs/difference-in-differences.qmd index 2d7b7ac4..e9775262 100644 --- a/docs/difference-in-differences.qmd +++ b/docs/difference-in-differences.qmd @@ -61,7 +61,7 @@ pf.panelview( ylab="Cohort", xlab="Year", title="Treatment Assignment Cohorts", - figsize=(0.5, 0.5), + figsize=(6, 5), ) ``` @@ -78,7 +78,7 @@ pf.panelview( ylab="Cohort", xlab="Year", title="Treatment Assignment Cohorts", - figsize=(0.5, 0.5), + figsize=(6, 5), ) ``` @@ -98,7 +98,7 @@ pf.panelview( ylab="Unit", xlab="Year", title="Treatment Assignment (all units)", - figsize=(0.5, 0.5), + figsize=(6, 5), ) ``` @@ -118,8 +118,9 @@ pf.panelview( time="year", treat="treat", collapse_to_cohort=True, - title = "Outcome Plot", - figsize=(2, 0.75), + title="Outcome Plot", + legend=True, + figsize=(7, 2.5), ) ``` @@ -143,7 +144,8 @@ pf.panelview( treat="treat", subsamp=100, title = "Outcome Plot", - figsize=(2, 0.75), + legend=True, + figsize=(7, 2.5), ) ``` diff --git a/pyfixest/did/visualize.py b/pyfixest/did/visualize.py index 9270abbc..8371a6d3 100644 --- a/pyfixest/did/visualize.py +++ b/pyfixest/did/visualize.py @@ -135,6 +135,7 @@ def panelview( time=time, treat=treat, outcome=outcome, + collapse_to_cohort=collapse_to_cohort, ax=ax, xlab=xlab, ylab=ylab, @@ -161,6 +162,7 @@ def panelview( ax=ax, xlab=xlab, ylab=ylab, + figsize=figsize, legend=legend, noticks=noticks, title=title, @@ -211,6 +213,7 @@ def get_treatment_start(x: pd.DataFrame) -> pd.Timestamp: ) data_agg = data_agg.rename(columns={"treatment_start": unit}) + data_agg[unit] = data_agg[unit].fillna("no_treatment") data = data_agg.copy() data_pivot = data_agg.pivot(index=unit, columns=time, values=outcome) @@ -224,6 +227,7 @@ def _plot_panelview_output_plot( time: str, treat: str, outcome: str, + collapse_to_cohort: Optional[bool] = None, ax: Optional[plt.Axes] = None, xlab: Optional[str] = None, ylab: Optional[str] = None, @@ -234,10 +238,13 @@ def _plot_panelview_output_plot( figsize: Optional[tuple] = (11, 3), ) -> plt.Axes: if not ax: - f, ax = plt.subplots(figsize=figsize, dpi=300) + f, ax = plt.subplots(figsize=figsize) for unit_id in data_pivot.index: unit_data = data_pivot.loc[unit_id] - treatment_times = data[(data[unit] == unit_id) & (data[treat])][time] + if collapse_to_cohort: + treatment_times = data[(data[time] == unit_id) & (data[treat])][time] + else: + treatment_times = data[(data[unit] == unit_id) & (data[treat])][time] # If the unit never receives treatment, plot the line in grey if treatment_times.empty: @@ -323,12 +330,13 @@ def _plot_panelview( ax: Optional[plt.Axes] = None, xlab: Optional[str] = None, ylab: Optional[str] = None, + figsize: Optional[tuple] = (11, 3), legend: Optional[bool] = False, noticks: Optional[bool] = False, title: Optional[str] = None, ) -> plt.Axes: if not ax: - f, ax = plt.subplots() + f, ax = plt.subplots(figsize=figsize) cax = ax.matshow(treatment_quilt, cmap="viridis", aspect="auto") f.colorbar(cax) if legend else None ax.set_xlabel(xlab) if xlab else None