diff --git a/apps/hannk/halide/common_halide.h b/apps/hannk/halide/common_halide.h index 82a9e22d408f..e499177a9410 100644 --- a/apps/hannk/halide/common_halide.h +++ b/apps/hannk/halide/common_halide.h @@ -39,6 +39,8 @@ Halide::Expr align(const Halide::Expr &x, const Halide::Expr &n); // where N is the number of bits of the narrowed result minus one. Halide::Expr multiply_2x_high(const Halide::Expr &a, const Halide::Expr &b); +// For a visualization of the approx_* functions and their errors, see: +// apps/hannk/halide/docs/approx_log2_and_applications.ipynb // Approximate log2(x/2^q_x)*2^q. // q must be less than 16. Halide::Expr approx_log2(int q, const Halide::Expr &x, int q_x, const Halide::Type &type = Halide::Int(32)); diff --git a/apps/hannk/halide/docs/approx_log2_and_applications.ipynb b/apps/hannk/halide/docs/approx_log2_and_applications.ipynb new file mode 100644 index 000000000000..d4771b5219b3 --- /dev/null +++ b/apps/hannk/halide/docs/approx_log2_and_applications.ipynb @@ -0,0 +1,382 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "id": "r1XiiUQGUjpx" + }, + "source": [ + "import numpy as np\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Many architectures have shifts where the right-hand-side is signed. A negative\n", + "# RHS is the same as a positive shift in the other direction.\n", + "def shift_right(x, y):\n", + " return np.floor(x / 2**y)\n", + "def shift_left(x, y):\n", + " return np.floor(x * 2**y)\n", + "def rounding_shift_right(x, y):\n", + " return np.round(x / 2**y)\n", + "def rounding_shift_left(x, y):\n", + " return np.round(x * 2**y)\n", + "\n", + "def bitwise_and(x, y):\n", + " return np.mod(x, y + 1)\n", + "\n", + "# This is sqrdmulh on ARM\n", + "def multiply_2x_high(x, y):\n", + " return rounding_shift_right(x * y, 15)\n", + "\n", + "def relative_error(x, y):\n", + " return (x - y) / (np.maximum(x, y) + 1e-3)\n", + "\n", + "def plot_results(x, exact, approxs, title, logx = False, logy = False, relative = False, log2_xscale = 0, log2_yscale = 0):\n", + " fig, [p1, p2] = plt.subplots(2, 1)\n", + "\n", + " p1.set_xlabel('x')\n", + " if logx:\n", + " p1.set_xscale('log')\n", + " p1.set_ylabel(title)\n", + " if logy:\n", + " p1.set_yscale('log')\n", + "\n", + " xscale = 2**log2_xscale\n", + " yscale = 2**log2_yscale\n", + "\n", + " exact = np.round(exact*yscale)/yscale\n", + "\n", + " p1.plot(x/xscale, exact)\n", + " for approx in approxs:\n", + " p1.plot(x/xscale, approx/yscale)\n", + "\n", + " p2.set_xlabel('x')\n", + " if logx:\n", + " p2.set_xscale('log')\n", + "\n", + " p2.set_ylabel('relative error' if relative else 'error')\n", + " for approx in approxs:\n", + " p2.plot(x/xscale, relative_error(approx/yscale, exact) if relative else approx/yscale - exact)\n", + "\n", + "def eval_poly(x, p, q):\n", + " x1 = rounding_shift_left(x, 15 - q)\n", + " y = p[0]\n", + " xi = x1\n", + " for i in p[1:]:\n", + " y = y + multiply_2x_high(i, xi)\n", + " xi = multiply_2x_high(xi, x1)\n", + " return rounding_shift_right(y, 15 - q)\n", + "\n", + "points = 6\n", + "degree = 3\n", + "log2_poly_x = np.arange(points, 2 * points + 1) / points\n", + "log2_poly_y = np.log2(log2_poly_x)\n", + "log2_poly = np.polyfit(log2_poly_x - 1, log2_poly_y, degree)\n", + "\n", + "exp2_poly_x = np.arange(points, 2 * points + 1) / points\n", + "exp2_poly_y = np.exp2(exp2_poly_x - 1) - 1\n", + "exp2_poly = np.polyfit(exp2_poly_x - 1, exp2_poly_y, degree)\n", + "\n", + "log2_poly = log2_poly[::-1]\n", + "exp2_poly = exp2_poly[::-1]\n", + "\n", + "print(log2_poly)\n", + "print(exp2_poly)\n", + "\n", + "log2_poly = np.round(log2_poly * 2**15)\n", + "exp2_poly = np.round(exp2_poly * 2**15)\n", + "exp2_poly[0] = 0\n", + "\n", + "print(log2_poly)\n", + "print(exp2_poly)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "1xjo4hIEo_z5" + }, + "source": [ + "# Approximate N*log2(x*2^q_x), where N = 2^q, and the intermediate computations are\n", + "# restricted to be integers.\n", + "def approx_log2(x, q, q_x = 0):\n", + " # This can be computed with count_leading_zeros\n", + " floor_log2_x = np.select([x > 0], [np.floor(np.log2(x))], [-1])\n", + "\n", + " # We've computed log2(x*2^q_x) = log2(x) + q_x. Subtract that offset now\n", + " # before multiplying by the result quantization.\n", + " result = shift_left(floor_log2_x - q_x, q)\n", + "\n", + " frac = bitwise_and(shift_right(x, floor_log2_x - q), 2**q - 1)\n", + "\n", + " return result + eval_poly(frac, log2_poly, q)\n", + "\n", + "x = np.arange(1, 10000)\n", + "q = 15\n", + "q_x = 2\n", + "log2_x = np.log2(x / 2**q_x)\n", + "approx_log2_x = approx_log2(x, q, q_x)\n", + "\n", + "plot_results(x, log2_x, [approx_log2_x], 'log2(x)', logx=True, log2_xscale=q_x, log2_yscale=q)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "6uJN5muLsLdE" + }, + "source": [ + "\n", + "# Approximate 2^(x/2^q_x)*2^q\n", + "def approx_exp2(x, q_x, q):\n", + " int_part = shift_right(x, q_x)\n", + " frac_part = x - shift_left(int_part, q_x)\n", + "\n", + " frac_part = eval_poly(frac_part, exp2_poly, q_x)\n", + "\n", + " exp_int_part = shift_left(1, int_part + q)\n", + " return exp_int_part + rounding_shift_right(exp_int_part * frac_part, q_x)\n", + "\n", + "q_x = 10\n", + "q = 15\n", + "x = np.arange(-4000, 2000)\n", + "approx_exp2_x = approx_exp2(x, q_x, q)\n", + "exact = np.exp2(x / 2**q_x)\n", + "\n", + "plot_results(x, exact, [approx_exp2_x], '2^x', False, True, relative=True, log2_xscale=q_x, log2_yscale=q)\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "5BP-edzCmNBi" + }, + "source": [ + "q = 15\n", + "x = np.arange(10, 10000) * 10\n", + "round_trip_x = approx_exp2(approx_log2(x, q), q, 0)\n", + "\n", + "plot_results(x, x, [round_trip_x], '2^log2(x)', logx=True, logy=True, relative=True)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "nyrzI90uNH1s" + }, + "source": [ + "# Approximate 2^q*sqrt(2^(x/2^q_x))\n", + "def sqrt_approx_exp2(x, q_x, q):\n", + " return approx_exp2(x, q_x + 1, q)\n", + "\n", + "q = 11\n", + "q_x = 8\n", + "x = np.arange(-1000, 2000)\n", + "approx_exp2_x = sqrt_approx_exp2(x, q_x, q)\n", + "exact = np.sqrt(np.exp2(x / 2**q_x))\n", + "\n", + "plot_results(x, exact, [approx_exp2_x], 'sqrt(2^x)', relative=True, log2_xscale=q_x, log2_yscale=q)\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Kno5t4VihCTL" + }, + "source": [ + "# Approximate sqrt(x) = 2^((1/2)*log2(x))\n", + "def approx_sqrt(x, q):\n", + " # log2(x) will never be larger than 32, for 32-bit x. So to make the result\n", + " # fit in a 16-bit integer, we can make the precision 2^16/32 = 2048.\n", + " q_x = 11;\n", + "\n", + " log2_sqrt_x = approx_log2(x, q_x - 1)\n", + " return approx_exp2(log2_sqrt_x, q_x, q)\n", + "\n", + "q = 15\n", + "x = np.arange(1, 10000)**2\n", + "sqrt_x = np.sqrt(x)\n", + "approx_sqrt_x = approx_sqrt(x, q)\n", + "\n", + "plot_results(x, sqrt_x, [approx_sqrt_x], 'sqrt(x)', log2_yscale=q, relative=True)\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "0dMecIGr92WY" + }, + "source": [ + "# Approximate 2^31/sqrt(x) = 2^(-(1/2)*log2(x))\n", + "def approx_reciprocal_sqrt(x):\n", + " q = 15\n", + " log2_sqrt_x = approx_log2(x, q - 1)\n", + " return approx_exp2(-log2_sqrt_x, q, 31)\n", + "\n", + "x = np.arange(1, 10000)**2\n", + "inv_sqrt_x = 1 / np.sqrt(x)\n", + "approx_reciprocal_sqrt_x = approx_reciprocal_sqrt(x)\n", + "\n", + "plot_results(x, inv_sqrt_x, [approx_reciprocal_sqrt_x], '1/sqrt(x)', True, True, True, log2_yscale=31)\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "VFC9aUFcc8d7" + }, + "source": [ + "# Approximate 2^32/x = 2^32*2^(-log2(x))\n", + "def approx_reciprocal(x):\n", + " q = 15;\n", + " log2_x = approx_log2(x, q)\n", + " return approx_exp2(-log2_x, q, 31)\n", + "\n", + "x = 1.01**np.arange(0, 2000)\n", + "inv_x = 1 / x\n", + "approx_inv_x = approx_reciprocal(x)\n", + "# This is ~sqrt(2) times more accurate, but maybe not practical for large x.\n", + "approx_inv_sqrt_x2 = approx_reciprocal_sqrt(x*x)\n", + "\n", + "plot_results(x, inv_x, [approx_inv_x], '1/x', True, True, log2_yscale=31, relative=True)\n", + "plot_results(x, inv_x, [approx_inv_sqrt_x2], '1/x', True, True, log2_yscale=31, relative=True)\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "6BhQzLIZCcKC" + }, + "source": [ + "# Approximate log2(exp2(x) + c)\n", + "def approx_log2_exp2_plus_constant(x, c, q_x, q):\n", + " # When x/2^q_x is large, approx_exp2 below will overflow. But when it is large\n", + " # we don't need it to be very precise\n", + " q_exp = 16 #np.minimum(16, 16 - np.floor(np.log2(np.maximum(x, 1))))\n", + " one = 2**q_exp\n", + "\n", + " one_plus_exp2_x = one * c + approx_exp2(x, q_x, q_exp)\n", + " # Mimic overflow of int32\n", + " one_plus_exp2_x = np.mod(one_plus_exp2_x, 2**31)\n", + "\n", + " raw = approx_log2(one_plus_exp2_x, q, q_exp)\n", + "\n", + " line = rounding_shift_right(x, q_x - q)\n", + "\n", + " threshold = 30 - q_exp\n", + " result = np.select([shift_right(x, q_x) < threshold], [raw], line)\n", + " return result\n", + "\n", + "def approx_log2p1_exp2(x, q_x, q):\n", + " return approx_log2_exp2_plus_constant(x, 1, q_x, q)\n", + "\n", + "def approx_log2m1_exp2(x, q_x, q):\n", + " return approx_log2_exp2_plus_constant(x, -1, q_x, q)\n", + "\n", + "x = np.arange(-4000, 4000)*8\n", + "q_x = 11\n", + "q = 15\n", + "\n", + "exact = np.log2(np.exp2(x / 2**q_x) + 1)\n", + "approx = approx_log2p1_exp2(x, q_x, q)\n", + "plot_results(x, exact, [approx], 'log2(2^x + 1)', log2_xscale=q_x, log2_yscale=q)\n", + "\n", + "x = np.arange(1, 4000)*8\n", + "exact = np.log2(np.exp2(x / 2**q_x) - 1)\n", + "approx = approx_log2m1_exp2(x, q_x, q)\n", + "plot_results(x, exact, [approx], 'log2(2^x - 1)', log2_xscale=q_x, log2_yscale=q)\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "G6n1u8fcUf-3" + }, + "source": [ + "# Approximate logistic(x) = 1/(e^-x + 1)\n", + "# = 2^log2(1/(e^-x + 1))\n", + "# = 2^-log2(e^-x + 1)\n", + "def approx_logistic(x, q_x, q):\n", + " x2 = multiply_2x_high(x, np.round(-np.log2(np.exp(1)) * 2**14))\n", + " q_exp = 11\n", + " log2_d = approx_log2p1_exp2(x2, q_x - 1, q_exp)\n", + " return approx_exp2(-log2_d, q_exp, q)\n", + "\n", + "x = np.arange(-4000, 4000)*8\n", + "q_x = 11\n", + "q = 15\n", + "exact = 1 / (1 + np.exp(-x / 2**q_x))\n", + "approx = approx_logistic(x, q_x, q)\n", + "plot_results(x, exact, [approx], '1/(1 + e^-x)', log2_xscale=q_x, log2_yscale=q)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "LBXXNc_8twQD" + }, + "source": [ + "# Approximate tanh(x) = (e^2x - 1)/(e^2x + 1)\n", + "# = 2^log2((e^2x - 1)/(e^2x + 1))\n", + "# = 2^(log2(e^2x - 1) - log2(e^2x + 1))\n", + "def approx_tanh(x, q_x, q):\n", + " abs_x_base2 = multiply_2x_high(np.abs(x), np.round(np.log2(np.exp(1)) * 2**14))\n", + " q_exp = 11\n", + " log2_n = approx_log2m1_exp2(abs_x_base2, q_x - 2, q_exp)\n", + " log2_d = approx_log2p1_exp2(abs_x_base2, q_x - 2, q_exp)\n", + " # Saturate at int16\n", + " log2_n = np.clip(log2_n, -(2**15), 2**15)\n", + " log2_d = np.clip(log2_d, -(2**15), 2**15)\n", + " return np.sign(x) * approx_exp2(log2_n - log2_d, q_exp, q)\n", + "\n", + "x = np.arange(-4000, 4000)*8\n", + "q_x = 12\n", + "q = 15\n", + "exact = np.tanh(x / 2**q_x)\n", + "approx = approx_tanh(x, q_x, q)\n", + "\n", + "points = 20\n", + "poly_x = np.arange(0, points * 3) / points\n", + "poly_y = np.tanh(poly_x)\n", + "poly = np.polyfit(poly_x, poly_y, 6)\n", + "approx2 = np.polyval(poly, x / 2**q_x) * 2**q\n", + "\n", + "\n", + "plot_results(x, exact, [approx], 'tanh(x)', log2_xscale=q_x, log2_yscale=q)" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file