From ed7016a1271a2cd480437a79f32e122dc96fd1ca Mon Sep 17 00:00:00 2001 From: Even Rouault Date: Sat, 31 Aug 2024 18:14:53 +0200 Subject: [PATCH] SQLite SQL dialect: add MEDIAN, PERCENTILE, PERCENTILE_CONT and MODE ordered-set aggregate functions --- autotest/ogr/ogr_sqlite.py | 117 +++++++ doc/source/user/sql_sqlite_dialect.rst | 14 +- .../sqlite/ogrsqlitesqlfunctionscommon.cpp | 290 ++++++++++++++++++ 3 files changed, 419 insertions(+), 2 deletions(-) diff --git a/autotest/ogr/ogr_sqlite.py b/autotest/ogr/ogr_sqlite.py index 621e5c990fe9..6ae3f9fddcdc 100755 --- a/autotest/ogr/ogr_sqlite.py +++ b/autotest/ogr/ogr_sqlite.py @@ -4106,6 +4106,123 @@ def test_ogr_sqlite_stddev(): assert f.GetField(1) == pytest.approx(0.5**0.5, rel=1e-15) +@pytest.mark.parametrize( + "input_values,expected_res", + [ + ([], None), + ([1], 1), + ([2.5, None, 1], 1.75), + ([3, 2.2, 1], 2.2), + ([1, "invalid"], None), + ], +) +def test_ogr_sqlite_median(input_values, expected_res): + """Test MEDIAN""" + + ds = ogr.Open(":memory:", update=1) + ds.ExecuteSQL("CREATE TABLE test(v)") + for v in input_values: + ds.ExecuteSQL( + "INSERT INTO test VALUES (%s)" + % ( + "NULL" + if v is None + else ("'" + v + "'") + if isinstance(v, str) + else str(v) + ) + ) + if expected_res is None and input_values: + with pytest.raises(Exception), gdaltest.error_handler(): + with ds.ExecuteSQL("SELECT MEDIAN(v) FROM test"): + pass + else: + with ds.ExecuteSQL("SELECT MEDIAN(v) FROM test") as sql_lyr: + f = sql_lyr.GetNextFeature() + assert f.GetField(0) == pytest.approx(expected_res) + with ds.ExecuteSQL("SELECT PERCENTILE(v, 50) FROM test") as sql_lyr: + f = sql_lyr.GetNextFeature() + assert f.GetField(0) == pytest.approx(expected_res) + with ds.ExecuteSQL("SELECT PERCENTILE_CONT(v, 0.5) FROM test") as sql_lyr: + f = sql_lyr.GetNextFeature() + assert f.GetField(0) == pytest.approx(expected_res) + + +def test_ogr_sqlite_percentile(): + """Test PERCENTILE""" + + ds = ogr.Open(":memory:", update=1) + ds.ExecuteSQL("CREATE TABLE test(v)") + ds.ExecuteSQL("INSERT INTO test VALUES (5),(6),(4),(7),(3),(8),(2),(9),(1),(10)") + + with pytest.raises(Exception), gdaltest.error_handler(): + with ds.ExecuteSQL("SELECT PERCENTILE(v, 'invalid') FROM test"): + pass + with pytest.raises(Exception), gdaltest.error_handler(): + with ds.ExecuteSQL("SELECT PERCENTILE(v, -0.1) FROM test"): + pass + with pytest.raises(Exception), gdaltest.error_handler(): + with ds.ExecuteSQL("SELECT PERCENTILE(v, 100.1) FROM test"): + pass + with pytest.raises(Exception), gdaltest.error_handler(): + with ds.ExecuteSQL("SELECT PERCENTILE(v, v) FROM test"): + pass + + +def test_ogr_sqlite_percentile_cont(): + """Test PERCENTILE_CONT""" + + ds = ogr.Open(":memory:", update=1) + ds.ExecuteSQL("CREATE TABLE test(v)") + ds.ExecuteSQL("INSERT INTO test VALUES (5),(6),(4),(7),(3),(8),(2),(9),(1),(10)") + + with pytest.raises(Exception), gdaltest.error_handler(): + with ds.ExecuteSQL("SELECT PERCENTILE_CONT(v, 'invalid') FROM test"): + pass + with pytest.raises(Exception), gdaltest.error_handler(): + with ds.ExecuteSQL("SELECT PERCENTILE_CONT(v, -0.1) FROM test"): + pass + with pytest.raises(Exception), gdaltest.error_handler(): + with ds.ExecuteSQL("SELECT PERCENTILE_CONT(v, 1.1) FROM test"): + pass + + +@pytest.mark.parametrize( + "input_values,expected_res", + [ + ([], None), + ([1, 2, None, 3, 2], 2), + (["foo", "bar", "baz", "bar"], "bar"), + ([1, "foo", 2, "foo", "bar"], "foo"), + ([1, "foo", 2, "foo", 1], "foo"), + ], +) +def test_ogr_sqlite_mode(input_values, expected_res): + """Test MODE""" + + ds = ogr.Open(":memory:", update=1) + ds.ExecuteSQL("CREATE TABLE test(v)") + for v in input_values: + ds.ExecuteSQL( + "INSERT INTO test VALUES (%s)" + % ( + "NULL" + if v is None + else ("'" + v + "'") + if isinstance(v, str) + else str(v) + ) + ) + if expected_res is None and input_values: + with pytest.raises(Exception), gdaltest.error_handler(): + with ds.ExecuteSQL("SELECT MODE(v) FROM test"): + pass + else: + with ds.ExecuteSQL("SELECT MODE(v) FROM test") as sql_lyr: + f = sql_lyr.GetNextFeature() + assert f.GetField(0) == expected_res + + def test_ogr_sqlite_run_deferred_actions_before_start_transaction(): ds = ogr.Open(":memory:", update=1) diff --git a/doc/source/user/sql_sqlite_dialect.rst b/doc/source/user/sql_sqlite_dialect.rst index 0a539e115d26..0bd4c97e52e2 100644 --- a/doc/source/user/sql_sqlite_dialect.rst +++ b/doc/source/user/sql_sqlite_dialect.rst @@ -208,8 +208,18 @@ Statistics functions In addition to standard COUNT(), SUM(), AVG(), MIN(), MAX(), the following aggregate functions are available: -- STDDEV_POP: (GDAL >= 3.10) numerical population standard deviation. -- STDDEV_SAMP: (GDAL >= 3.10) numerical `sample standard deviation `__ +- ``STDDEV_POP(numeric_value)``: (GDAL >= 3.10) numerical population standard deviation. +- ``STDDEV_SAMP(numeric_value)``: (GDAL >= 3.10) numerical `sample standard deviation `__ + +Ordered-set aggregate functions ++++++++++++++++++++++++++++++++ + +The following aggregate functions are available. Note that they require to allocate an amount of memory proportionnal to the number of selected rows (for ``MEDIAN``, ``PERCENTILE`` and ``PERCENTILE_CONT``) or to the number of values (for ``MODE``). + +- ``MEDIAN(numeric_value)``: (GDAL >= 3.10) (continuous) median (equivalent to ``PERCENTILE(numeric_value, 50)``). NULL values are ignored. +- ``PERCENTILE(numeric_value, percentage)``: (GDAL >= 3.10) (continuous) percentile, with percentage between 0 and 100 (equivalent to ``PERCENTILE_CONT(numeric_value, percentage / 100)``). NULL values are ignored. +- ``PERCENTILE_CONT(numeric_value, fraction)``: (GDAL >= 3.10) (continuous) percentile, with fraction between 0 and 1. NULL values are ignored. +- ``MODE(value)``: (GDAL >= 3.10): mode, i.e. most frequent input value (strings and numeric values are supported), arbitrarily choosing the first one if there are multiple equally-frequent results. NULL values are ignored. Spatialite SQL functions ++++++++++++++++++++++++ diff --git a/ogr/ogrsf_frmts/sqlite/ogrsqlitesqlfunctionscommon.cpp b/ogr/ogrsf_frmts/sqlite/ogrsqlitesqlfunctionscommon.cpp index da3f35fd4b51..bf1ca9dd1b19 100644 --- a/ogr/ogrsf_frmts/sqlite/ogrsqlitesqlfunctionscommon.cpp +++ b/ogr/ogrsf_frmts/sqlite/ogrsqlitesqlfunctionscommon.cpp @@ -37,6 +37,9 @@ #include "ogrsqliteregexp.cpp" /* yes the .cpp file, to make it work on Windows with load_extension('gdalXX.dll') */ +#include +#include +#include #include #include "ogr_swq.h" @@ -321,6 +324,277 @@ static void OGRSQLITE_STDDEV_SAMP_Finalize(sqlite3_context *pContext) } } +/************************************************************************/ +/* OGRSQLITE_Percentile_Step() */ +/************************************************************************/ + +// Percentile related code inspired from https://sqlite.org/src/file/ext/misc/percentile.c +// of https://www.sqlite.org/draft/percentile.html + +// Constant addd to Percentile::rPct, since rPct is initialized to 0 when unset. +constexpr double PERCENT_ADD_CONSTANT = 1; + +namespace +{ +struct Percentile +{ + double rPct; /* PERCENT_ADD_CONSTANT more than the value for P */ + std::vector *values; /* Array of Y values */ +}; +} // namespace + +/* +** The "step" function for percentile(Y,P) is called once for each +** input row. +*/ +static void OGRSQLITE_Percentile_Step(sqlite3_context *pCtx, int argc, + sqlite3_value **argv) +{ + assert(argc == 2 || argc == 1); + + double rPct; + + if (argc == 1) + { + /* Requirement 13: median(Y) is the same as percentile(Y,50). */ + rPct = 50.0; + } + else if (sqlite3_user_data(pCtx) == nullptr) + { + /* Requirement 3: P must be a number between 0 and 100 */ + const int eType = sqlite3_value_numeric_type(argv[1]); + rPct = sqlite3_value_double(argv[1]); + if ((eType != SQLITE_INTEGER && eType != SQLITE_FLOAT) || rPct < 0.0 || + rPct > 100.0) + { + sqlite3_result_error(pCtx, + "2nd argument to percentile() is not " + "a number between 0.0 and 100.0", + -1); + return; + } + } + else + { + /* Requirement 3: P must be a number between 0 and 1 */ + const int eType = sqlite3_value_numeric_type(argv[1]); + rPct = sqlite3_value_double(argv[1]); + if ((eType != SQLITE_INTEGER && eType != SQLITE_FLOAT) || rPct < 0.0 || + rPct > 1.0) + { + sqlite3_result_error(pCtx, + "2nd argument to percentile_cont() is not " + "a number between 0.0 and 1.0", + -1); + return; + } + rPct *= 100.0; + } + + /* Allocate the session context. */ + auto p = static_cast( + sqlite3_aggregate_context(pCtx, sizeof(Percentile))); + if (!p) + return; + + /* Remember the P value. Throw an error if the P value is different + ** from any prior row, per Requirement (2). */ + if (p->rPct == 0.0) + { + p->rPct = rPct + PERCENT_ADD_CONSTANT; + } + else if (p->rPct != rPct + PERCENT_ADD_CONSTANT) + { + sqlite3_result_error(pCtx, + "2nd argument to percentile() is not the " + "same for all input rows", + -1); + return; + } + + /* Ignore rows for which the value is NULL */ + const int eType = sqlite3_value_type(argv[0]); + if (eType == SQLITE_NULL) + return; + + /* If not NULL, then Y must be numeric. Otherwise throw an error. + ** Requirement 4 */ + if (eType != SQLITE_INTEGER && eType != SQLITE_FLOAT) + { + sqlite3_result_error(pCtx, + "1st argument to percentile() is not " + "numeric", + -1); + return; + } + + /* Ignore rows for which the value is NaN */ + const double v = sqlite3_value_double(argv[0]); + if (std::isnan(v)) + { + return; + } + + if (!p->values) + p->values = new std::vector(); + try + { + p->values->push_back(v); + } + catch (const std::exception &) + { + delete p->values; + memset(p, 0, sizeof(*p)); + sqlite3_result_error_nomem(pCtx); + return; + } +} + +/************************************************************************/ +/* OGRSQLITE_Percentile_Finalize() */ +/************************************************************************/ + +/* +** Called to compute the final output of percentile() and to clean +** up all allocated memory. +*/ +static void OGRSQLITE_Percentile_Finalize(sqlite3_context *pCtx) +{ + auto p = static_cast(sqlite3_aggregate_context(pCtx, 0)); + if (!p) + return; + if (!p->values) + return; + if (!p->values->empty()) + { + std::sort(p->values->begin(), p->values->end()); + const double ix = (p->rPct - PERCENT_ADD_CONSTANT) * + static_cast(p->values->size() - 1) * 0.01; + const size_t i1 = static_cast(ix); + const size_t i2 = + ix == static_cast(i1) || i1 == p->values->size() - 1 + ? i1 + : i1 + 1; + const double v1 = (*p->values)[i1]; + const double v2 = (*p->values)[i2]; + const double vx = v1 + (v2 - v1) * static_cast(ix - i1); + sqlite3_result_double(pCtx, vx); + } + delete p->values; + memset(p, 0, sizeof(*p)); +} + +/************************************************************************/ +/* OGRSQLITE_Mode_Step() */ +/************************************************************************/ + +namespace +{ +struct Mode +{ + std::map *numericValues; + std::map *stringValues; + double mostFrequentNumValue; + std::string *mostFrequentStr; + uint64_t mostFrequentValueCount; + bool mostFrequentValueIsStr; +}; +} // namespace + +static void OGRSQLITE_Mode_Step(sqlite3_context *pCtx, int /*argc*/, + sqlite3_value **argv) +{ + const int eType = sqlite3_value_type(argv[0]); + if (eType == SQLITE_NULL) + return; + + if (eType == SQLITE_BLOB) + { + sqlite3_result_error(pCtx, "BLOB argument not supported for mode()", + -1); + return; + } + + /* Allocate the session context. */ + auto p = static_cast(sqlite3_aggregate_context(pCtx, sizeof(Mode))); + if (!p) + return; + + try + { + if (eType == SQLITE_TEXT) + { + const char *pszStr = + reinterpret_cast(sqlite3_value_text(argv[0])); + if (!p->stringValues) + { + p->stringValues = new std::map(); + p->mostFrequentStr = new std::string(); + } + const uint64_t count = ++(*p->stringValues)[pszStr]; + if (count > p->mostFrequentValueCount) + { + p->mostFrequentValueCount = count; + p->mostFrequentValueIsStr = true; + *(p->mostFrequentStr) = pszStr; + } + } + else + { + const double v = sqlite3_value_double(argv[0]); + if (std::isnan(v)) + return; + if (!p->numericValues) + p->numericValues = new std::map(); + const uint64_t count = ++(*p->numericValues)[v]; + if (count > p->mostFrequentValueCount) + { + p->mostFrequentValueCount = count; + p->mostFrequentValueIsStr = false; + p->mostFrequentNumValue = v; + } + } + } + catch (const std::exception &) + { + delete p->stringValues; + delete p->numericValues; + delete p->mostFrequentStr; + memset(p, 0, sizeof(*p)); + sqlite3_result_error_nomem(pCtx); + return; + } +} + +/************************************************************************/ +/* OGRSQLITE_Mode_Finalize() */ +/************************************************************************/ + +static void OGRSQLITE_Mode_Finalize(sqlite3_context *pCtx) +{ + auto p = static_cast(sqlite3_aggregate_context(pCtx, 0)); + if (!p) + return; + + if (p->mostFrequentValueCount) + { + if (p->mostFrequentValueIsStr) + { + sqlite3_result_text(pCtx, p->mostFrequentStr->c_str(), -1, + SQLITE_TRANSIENT); + } + else + { + sqlite3_result_double(pCtx, p->mostFrequentNumValue); + } + } + + delete p->stringValues; + delete p->numericValues; + delete p->mostFrequentStr; + memset(p, 0, sizeof(*p)); +} + /************************************************************************/ /* OGRSQLiteRegisterSQLFunctionsCommon() */ /************************************************************************/ @@ -360,6 +634,22 @@ static OGRSQLiteExtensionData *OGRSQLiteRegisterSQLFunctionsCommon(sqlite3 *hDB) nullptr, OGRSQLITE_STDDEV_Step, OGRSQLITE_STDDEV_SAMP_Finalize); + sqlite3_create_function(hDB, "median", 1, UTF8_INNOCUOUS, nullptr, nullptr, + OGRSQLITE_Percentile_Step, + OGRSQLITE_Percentile_Finalize); + + sqlite3_create_function(hDB, "percentile", 2, UTF8_INNOCUOUS, nullptr, + nullptr, OGRSQLITE_Percentile_Step, + OGRSQLITE_Percentile_Finalize); + + sqlite3_create_function( + hDB, "percentile_cont", 2, UTF8_INNOCUOUS, + const_cast("percentile_cont"), // any non-null ptr + nullptr, OGRSQLITE_Percentile_Step, OGRSQLITE_Percentile_Finalize); + + sqlite3_create_function(hDB, "mode", 1, UTF8_INNOCUOUS, nullptr, nullptr, + OGRSQLITE_Mode_Step, OGRSQLITE_Mode_Finalize); + pData->SetRegExpCache(OGRSQLiteRegisterRegExpFunction(hDB)); return pData;