Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add defaults to arguments #11

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
36 changes: 26 additions & 10 deletions legendgram/legendgram.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,65 @@
from .util import make_location as _make_location
from .util import _get_cmap
import numpy as np

from matplotlib.colors import Colormap
import matplotlib.pyplot as plt
from palettable.palette import Palette


def legendgram(f, ax, y, breaks, pal, bins=50, clip=None,
def legendgram(y, breaks=None, pal=None, bins=50, clip=None,
loc = 'lower left', legend_size=(.27,.2),
frameon=False, tick_params = None):
frameon=False, tick_params = None, f=None, ax=None):
'''
Add a histogram in a choropleth with colors aligned with map
...

Arguments
---------
f : Figure
ax : AxesSubplot
y : ndarray/Series
Values to map
breaks : list
breaks : list or int
[Optional. Default=ten evenly-spaced percentiles from the 1st to the 99th]
Sequence with breaks for each class (i.e. boundary values
for colors)
pal : palettable colormap or matplotlib colormap
for colors). If an integer is supplied, this is used as the number of
evenly-spaced percentiles to use in the discretization
pal : palettable colormap, matplotlib colormap, or str
palette to use to construct the legendgram. (default: None)
clip : tuple
[Optional. Default=None] If a tuple, clips the X
axis of the histogram to the bounds provided.
loc : string or int
valid legend location like that used in matplotlib.pyplot.legend
loc : str or int
valid legend location like that used in matplotlib.pyplot.legend
legend_size : tuple
tuple of floats between 0 and 1 describing the (width,height)
of the legend relative to the original frame.
frameon : bool (default: False)
whether to add a frame to the legendgram
tick_params : keyword dictionary
options to control how the histogram axis gets ticked/labelled.
f : Figure
ax : AxesSubplot

Returns
-------
axis containing the legendgram.
'''
if f is None:
f = plt.gcf()
if ax is None:
ax = plt.gca()
if pal is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I am missing something here, but why do you use None and an if clause for a default? You could just directly assign 'viridis' to the parameter in the api. This way users also see how the input out to look like?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious maybe I can learn something :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! I had intended to somehow detect the current matplotlib palette and use that if pal=None. But, I couldn't quite figure out how to do that. So, I deleted the exploration in this if branch & simply used pal = 'viridis'.

Ideally, you'd use pal=None and if pal is None, somehow check what colormap the user is using currently. I don't quite know how to do that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, yeah that sounds great!
do you think something like
'''
fig, ax = plt.subplots()
im = ax.imshow(a, cmap="viridis")

cm = im.get_cmap()
print(cm.name)
"""

through

'Matplotlib.cm.get_cmap()' could work?

Happy to leave this as is and deal with it later. Maybe add a note, so others can keep finding solutions? :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, well that's simole... how I missed it idk.

No reason to merge an incomplete if that works! I'll take a look tomorrow.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ljwolf Is this still under consideration? I put together a small gist demonstrating a helper function built on top of @slumnitz's idea. Basically, it uses the most recent colormap of an image, if the axis has any associated images. Otherwise, it defaults to viridis.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, totally! push on top of this so that it updates!

pal = _get_cmap(ax)
if breaks is None:
breaks = 10
if isinstance(breaks, int):
breaks = np.percentile(y, q=np.linspace(1,99,num=breaks))
k = len(breaks)
histpos = _make_location(ax, loc, legend_size=legend_size)
histax = f.add_axes(histpos)
N, bins, patches = histax.hist(y, bins=bins, color='0.1')
#---
if isinstance(pal, str):
pl = plt.get_cmap(pal)
if isinstance(pal, Palette):
assert k == pal.number, "provided number of classes does not match number of colors in palette."
pl = pal.get_mpl_colormap()
Expand Down
129 changes: 85 additions & 44 deletions legendgram/test_legendgram.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,39 @@
import unittest as ut
import pytest

import matplotlib as mpl
mpl.use("pdf")
import matplotlib.pyplot as plt
from pysal.contrib.viz.mapping import geoplot
import pysal as ps
from palettable import matplotlib as mplpal
from .util import inv_lut
from .legendgram import legendgram

import geopandas
import libpysal.examples as examples
from mapclassify import Quantiles
import numpy as np
from palettable import matplotlib as mplpal

class Test_Legendgram(ut.TestCase):
def setUp(self):
self.data = geopandas.read_file(ps.examples.get_path('south.shp'))
from .legendgram import legendgram
from .util import inv_lut


class Test_Legendgram:
def setup_method(self):
self.data = geopandas.read_file(examples.get_path('south.shp'))
self.test_attribute = 'HR70'
self.k = 10
self.breaks = ps.Quantiles(self.data[self.test_attribute].values, k=self.k).bins
self.breaks = Quantiles(self.data[self.test_attribute].values, k=self.k).bins
self.pal = mplpal.Inferno_10
self.cmap = mplpal.Inferno_10.get_mpl_colormap()

def genframe(self):
f,ax = plt.subplots()
def genframe(self, f=None, ax=None, cmap="inferno"):
if not f:
f, ax = plt.subplots()
self.data.plot(self.test_attribute, scheme='Quantiles',
k=self.k, cmap = mplpal.Inferno_10, ax=ax)
k=self.k, cmap=cmap, ax=ax)
return f,ax

def test_call(self):
f,ax = self.genframe()
aout = legendgram(f,ax, self.data[self.test_attribute].values,
self.breaks, mplpal.Inferno_10)
f, ax = self.genframe()
aout = legendgram(self.data[self.test_attribute].values,
breaks=self.breaks, pal=self.pal, f=f, ax=ax)
plt.close(f)

def test_positioning(self):
Expand All @@ -40,11 +45,11 @@ def test_positioning(self):
bboxes = []
for i in range(1,11):
f,ax = self.genframe()
aout = legendgram(f,ax, self.data[self.test_attribute].values,
self.breaks, mplpal.Inferno_10, loc=i)
aout = legendgram(self.data[self.test_attribute].values,
breaks=self.breaks, pal=self.pal, loc=i, f=f, ax=ax)
f2,ax2 = self.genframe()
aout2 = legendgram(f2,ax2, self.data[self.test_attribute].values,
self.breaks, mplpal.Inferno_10, loc=inv_lut[i])
aout2 = legendgram(self.data[self.test_attribute].values, loc=inv_lut[i],
breaks=self.breaks, pal=self.pal, f=f2, ax=ax2)
print(i,inv_lut[i])
bbox = aout.get_position()
bbox2 = aout2.get_position()
Expand All @@ -54,59 +59,95 @@ def test_positioning(self):
plt.close(f)
plt.close(f2)
for i in range(len(bboxes)-1):
self.assertTrue(bboxes[i].bounds != bboxes[i+1].bounds)
assert bboxes[i].bounds != bboxes[i+1].bounds
f,ax = self.genframe()
aout = legendgram(f,ax, self.data[self.test_attribute].values,
self.breaks, mplpal.Inferno_10, loc=0)
aout = legendgram(self.data[self.test_attribute].values,
breaks=self.breaks, pal=self.pal, loc=0, f=f, ax=ax)
bestbbox = aout.get_position()
print(bestbbox.bounds, bboxes[2].bounds)
np.testing.assert_allclose(bestbbox.bounds, bboxes[2].bounds) #best == bottom left

def test_tickparams(self):
f,ax = self.genframe()
aout = legendgram(f,ax, self.data[self.test_attribute].values,
self.breaks, mplpal.Inferno_10, tick_params=dict(labelsize=20))
aout = legendgram(self.data[self.test_attribute].values, breaks=self.breaks,
pal=self.pal, tick_params=dict(labelsize=20), f=f, ax=ax)
ticks = aout.get_xticklabels()
for tick in ticks:
self.assertEqual(tick.get_fontsize(), 20)
assert tick.get_fontsize() == 20
plt.close(f)

def test_frameon(self):
f,ax = self.genframe()
aout = legendgram(f,ax, self.data[self.test_attribute].values,
self.breaks, mplpal.Inferno_10, frameon=True)
self.assertTrue(aout.get_frame_on())
aout = legendgram(self.data[self.test_attribute].values,
breaks=self.breaks, pal=self.pal, frameon=True, f=f, ax=ax)
assert aout.get_frame_on()
plt.close(f)

f,ax = self.genframe()
aout = legendgram(f,ax, self.data[self.test_attribute].values,
self.breaks, mplpal.Inferno_10, frameon=False)
self.assertTrue(not aout.get_frame_on())
aout = legendgram(self.data[self.test_attribute].values,
breaks=self.breaks, pal=self.pal, frameon=False, f=f, ax=ax)
assert not aout.get_frame_on()
plt.close(f)

@ut.skip('Not sure how to test this')
@pytest.mark.skip('Not sure how to test this')
def test_sizing(self):
raise NotImplementedError('Not sure how to test this yet...')

@pytest.mark.skip('this should test that loc=[*subax_corner, *subax_dimension] passes through make_location unphased.')
def test_passthrough_sizing(self):
raise NotImplementedError('this should test that loc=[*subax_corner, *subax_dimension] passes through make_location unphased.')

def test_clip(self):
f,ax = self.genframe()
aout = legendgram(f,ax, self.data[self.test_attribute].values,
self.breaks, mplpal.Inferno_10, clip=(10,20))
self.assertEquals(aout.get_xlim(), (10,20))
aout = legendgram(self.data[self.test_attribute].values,
breaks=self.breaks, pal=self.pal, clip=(10,20), f=f, ax=ax)
assert aout.get_xlim() == (10,20)

def test_palettebreak_mismatch(self):
f,ax = self.genframe()
with self.assertRaises(AssertionError):
aout = legendgram(f,ax, self.data[self.test_attribute].values,
self.breaks, mplpal.Inferno_12)
with pytest.raises(AssertionError):
legendgram(self.data[self.test_attribute].values,
breaks=self.breaks, pal=mplpal.Inferno_12, f=f, ax=ax)

def test_matplotlib_cmap(self):
f,ax = self.genframe()
aout = legendgram(f,ax, self.data[self.test_attribute].values,
self.breaks, self.cmap)
aout = legendgram(self.data[self.test_attribute].values,
self.breaks, pal=self.cmap, f=f, ax=ax)
plt.close(f)

def test_pal_typeerror(self):
f,ax = self.genframe()
with self.assertRaises(ValueError):
aout = legendgram(f,ax, self.data[self.test_attribute].values,
self.breaks, 'pal')
with pytest.raises(ValueError):
legendgram(self.data[self.test_attribute].values,
breaks=self.breaks, pal='pal', f=f, ax=ax)

def test_no_previous_cmap_from_ax(self):
f, ax = plt.subplots()

with pytest.warns(
UserWarning, match="There is no data associated with the `ax`",
):
aout = legendgram(self.data[self.test_attribute].values,
self.breaks, pal=None, f=f, ax=ax)
plt.close(f)

def test_one_previous_cmap_from_ax(self):
f, ax = self.genframe()
aout = legendgram(self.data[self.test_attribute].values,
self.breaks, pal=None, f=f, ax=ax)
plt.close(f)

def test_two_previous_cmaps_from_ax(self):
f, ax = self.genframe()
f, ax = self.genframe(f=f, ax=ax, cmap="twilight")

with pytest.warns(
UserWarning,
match=(
"There are 2 unique colormaps associated with"
"the axes. Defaulting to last colormap: 'twilight'"
),
):
aout = legendgram(self.data[self.test_attribute].values,
self.breaks, pal=None, f=f, ax=ax)
plt.close(f)
39 changes: 39 additions & 0 deletions legendgram/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import warnings

from matplotlib import colormaps as cm
from matplotlib.axes._axes import Axes
from matplotlib.collections import Collection
from matplotlib.colors import ListedColormap

loc_lut = {'best' : 0,
'upper right' : 1,
Expand Down Expand Up @@ -34,6 +40,9 @@ def make_location(ax,loc, legend_size=(.27,.2)):


"""
if isinstance(loc, (list, tuple)):
assert len(loc) == 4
return loc
position = ax.get_position()
if isinstance(legend_size, float):
legend_size = (legend_size, legend_size)
Expand Down Expand Up @@ -68,3 +77,33 @@ def make_location(ax,loc, legend_size=(.27,.2)):
elif loc.lower() == 'upper right':
anchor_x, anchor_y = position.x0 + right_offset, position.y0 + top_offset
return [anchor_x, anchor_y, legend_width, legend_height]


def _get_cmap(_ax: Axes) -> ListedColormap:
"""Detect the most recent matplotlib colormap used, if previously rendered."""
_child_cmaps = [
(cc.cmap, cc.cmap.name) for cc
in _ax.properties()["children"]
if isinstance(cc, Collection)
]
has_child_cmaps = len(_child_cmaps)
n_unique_cmaps = len(set(cc[1] for cc in _child_cmaps))
if has_child_cmaps:
cmap, cmap_name = _child_cmaps[-1]
if n_unique_cmaps > 1:
warnings.warn(
(
f"There are {n_unique_cmaps} unique colormaps associated with"
f"the axes. Defaulting to last colormap: '{cmap_name}'"
),
UserWarning,
stacklevel=2
)
else:
warnings.warn(
"There is no data associated with the `ax`.",
UserWarning,
stacklevel=2
)
cmap = cm.get_cmap("viridis")
return cmap
File renamed without changes.