diff --git a/cmeutils/plotting.py b/cmeutils/plotting.py index 08f215c..e0232b7 100644 --- a/cmeutils/plotting.py +++ b/cmeutils/plotting.py @@ -1,6 +1,6 @@ +import matplotlib.pyplot as plt import numpy as np - def get_histogram(data, normalize=False, bins="auto"): """Bins a 1-D array of data into a histogram using the numpy.histogram method. @@ -31,3 +31,63 @@ def get_histogram(data, normalize=False, bins="auto"): bin_widths = np.diff(bin_borders) bin_centers = bin_borders[:-1] + bin_widths / 2 return bin_centers, bin_heights + +def threedplot( + x, + y, + z, + xlabel = "xlabel", + ylabel = "ylabel", + zlabel = "zlabel", + plot_name = "plot_name" + ): + + '''Plot a 3d heat map from 3 lists of numbers. This function is useful + for plotting a dependent variable as a function of two independent variables. + In the example below we use f(x,y)= -x^2 - y^2 +6 because it looks cool. + + Example + ------- + + We create two indepent variables and a dependent variable in the z axis and + plot the result. Here z is the equation of an elliptic paraboloid. + + import random + + x = [] + for i in range(0,1000): + n = random.uniform(-20,20) + x.append(n) + + y = [] + for i in range(0,1000): + n = random.uniform(-20,20) + y.append(n) + + z = [] + for i in range(0,len(x)): + z.append(-x[i]**2 - y[i]**2 +6) + + fig = threedplot(x,y,z) + fig.show() + + Parameters + ---------- + + x,y,z : list of int/floats + + xlabel, ylabel, zlabel : str + + plot_name : str + + + ''' + fig = plt.figure(figsize = (10, 10), facecolor = 'white') + ax = plt.axes(projection='3d') + ax.set_xlabel(xlabel,fontdict=dict(weight='bold'),fontsize=12) + ax.set_ylabel(ylabel,fontdict=dict(weight='bold'),fontsize=12) + ax.set_zlabel(zlabel,fontdict=dict(weight='bold'),fontsize=12) + p = ax.scatter(x, y, z, c=z, cmap='rainbow', linewidth=7); + plt.colorbar(p, pad = .1, aspect = 2.3) + + return fig diff --git a/cmeutils/tests/test_plotting.py b/cmeutils/tests/test_plotting.py index fdfe70d..0bc69a3 100644 --- a/cmeutils/tests/test_plotting.py +++ b/cmeutils/tests/test_plotting.py @@ -2,6 +2,8 @@ import pytest from cmeutils.plotting import get_histogram +from cmeutils.plotting import threedplot + from base_test import BaseTest @@ -16,3 +18,9 @@ def test_histogram_normalize(self): bin_c, bin_h = get_histogram(sample, normalize=True) assert all(bin_h <= 1) + def test_3dplot(self): + x = [1,2,3,4,5] + y = [1,2,3,4,5] + z = [1,2,3,4,5] + threedplot(x,y,z) +