Skip to content

Commit

Permalink
Found issue with arg pointer going out of scope.
Browse files Browse the repository at this point in the history
  • Loading branch information
jrenaud90 committed Nov 27, 2024
1 parent a8cf831 commit 8d28572
Show file tree
Hide file tree
Showing 2 changed files with 276 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CyRK/cy/cysolver_test.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def cytester(
cdef double[2] t_span_arr
cdef double* t_span_ptr = &t_span_arr[0]

cdef int num_extra = 0
cdef size_t num_extra = 0
cdef DiffeqFuncType diffeq = NULL
cdef PreEvalFunc pre_eval_func = NULL
if diffeq_number == 0:
Expand Down
275 changes: 275 additions & 0 deletions Tests/Untitled.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "3f8b4f01-9b64-4a55-b902-29f4c5e7ec15",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.11.6a0.dev0\n"
]
}
],
"source": [
"from CyRK import pysolve_ivp, WrapCySolverResult, __version__\n",
"\n",
"print(__version__)\n",
"\n",
"import Cython"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "183cb38a-cb77-4e68-8c0e-3457d4772c22",
"metadata": {},
"outputs": [],
"source": [
"%load_ext cython"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5af7593-8cc6-4816-8ca7-dc7f4f649936",
"metadata": {},
"outputs": [],
"source": [
"%%cython -a -f\n",
"# distutils: language = c++\n",
"# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True, initializedcheck=False\n",
"\n",
"import numpy as np\n",
"cimport numpy as np\n",
"np.import_array()\n",
"\n",
"from libcpp.vector cimport vector\n",
"\n",
"from CyRK cimport cysolve_ivp, CySolveOutput, CySolverResult, DiffeqFuncType\n",
"from CyRK.cy.cysolver_test cimport lotkavolterra_diffeq\n",
"\n",
"\n",
"cdef DiffeqFuncType diffeq = lotkavolterra_diffeq\n",
"\n",
"\n",
"cdef CySolveOutput test_1():\n",
"\n",
" cdef double[2] t_span = [0., 10.]\n",
" cdef double* t_span_ptr = &t_span[0]\n",
"\n",
" cdef double[2] y0 = [10., 5.]\n",
" cdef double* y0_ptr = &y0[0]\n",
"\n",
" cdef size_t num_y = 2\n",
"\n",
" cdef vector[double] args = vector[double]()\n",
" args.resize(3)\n",
" args[0] = 1.0\n",
" args[1] = 1.0\n",
" args[2] = 9.81\n",
" cdef double* args_ptr = args.data()\n",
"\n",
" cdef CySolveOutput result = cysolve_ivp(\n",
" diffeq,\n",
" t_span_ptr,\n",
" y0_ptr,\n",
" num_y,\n",
" 1,\n",
" 1.0e-5,\n",
" 1.0e-6,\n",
" args_ptr,\n",
" 0,\n",
" 1_000_000,\n",
" 2_000,\n",
" True,\n",
" NULL,\n",
" 0,\n",
" NULL,\n",
" NULL,\n",
" NULL,\n",
" 10_000,\n",
" 0.0,\n",
" 100\n",
" )\n",
" return result\n",
"\n",
"from libc.stdlib cimport malloc, free, realloc\n",
"\n",
"cdef CySolveOutput test_2():\n",
"\n",
" cdef double[2] t_span = [0., 10.]\n",
" cdef double* t_span_ptr = &t_span[0]\n",
"\n",
" cdef double[2] y0 = [10., 5.]\n",
" cdef double* y0_ptr = &y0[0]\n",
"\n",
" cdef size_t num_y = 2\n",
"\n",
" cdef double* args_ptr = <double*>malloc(sizeof(double)*3)\n",
" args_ptr[0] = 1.0\n",
" args_ptr[1] = 1.0\n",
" args_ptr[2] = 9.81\n",
"\n",
" cdef CySolveOutput result = cysolve_ivp(\n",
" diffeq,\n",
" t_span_ptr,\n",
" y0_ptr,\n",
" num_y,\n",
" 1,\n",
" 1.0e-5,\n",
" 1.0e-6,\n",
" args_ptr,\n",
" 0,\n",
" 1_000_000,\n",
" 2_000,\n",
" True,\n",
" NULL,\n",
" 0,\n",
" NULL,\n",
" NULL,\n",
" NULL,\n",
" 10_000,\n",
" 0.0,\n",
" 100\n",
" )\n",
" realloc(args_ptr, sizeof(double)*3000)\n",
" cdef size_t i \n",
" for i in range(3000):\n",
" args_ptr[i] = -1.0\n",
" free(args_ptr)\n",
" return result\n",
"\n",
"cdef CySolveOutput res_shptr\n",
"cdef CySolverResult* res\n",
"\n",
"from libc.stdio cimport printf\n",
"\n",
"cdef double[2] y_interp\n",
"cdef double* y_interp_ptr = &y_interp[0]\n",
"cdef size_t i\n",
"\n",
"printf(\"\\nTest 1\\n\")\n",
"for i in range(10):\n",
" printf(\"\\tSubTest Num = %d\\n\", i)\n",
" res_shptr = test_1()\n",
" res = res_shptr.get()\n",
" printf(\"Test 1 Success = %d\\n\", res.success)\n",
" \n",
" y_interp_ptr[0] = 0.0\n",
" y_interp_ptr[1] = 0.0\n",
" \n",
" printf(\"Test 1; calling\\n\")\n",
" res.call(4.38, y_interp)\n",
" printf(\"Test 1; Call Finished. y0 = %e; y1 = %e\\n\", y_interp[0], y_interp[1])\n",
"\n",
"printf(\"\\nTest 2\\n\")\n",
"for i in range(10):\n",
" printf(\"\\tSubTest Num = %d\\n\", i)\n",
" res_shptr = test_2()\n",
" res = res_shptr.get()\n",
" printf(\"Test 2 Success = %d\\n\", res.success)\n",
" \n",
" y_interp_ptr[0] = 0.0\n",
" y_interp_ptr[1] = 0.0\n",
" \n",
" printf(\"Test 2; calling\\n\")\n",
" res.call(4.38, y_interp)\n",
" printf(\"Test 2; Call Finished. y0 = %e; y1 = %e\\n\", y_interp[0], y_interp[1])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "775687b3-9c4c-4e79-9b4f-efe89dad2931",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True Integration completed without issue.\n",
"[[ 7.03679027]\n",
" [124.35084124]\n",
" [ -0.24350841]\n",
" [ -0.85926419]]\n"
]
}
],
"source": [
"def test_1():\n",
" result = \\\n",
" pysolve_ivp(diffeq_args_extra, time_span, initial_conds,\n",
" method=\"RK45\",\n",
" args=args, rtol=rtol, atol=atol,\n",
" dense_output=True,\n",
" num_extra=2, first_step=0.0, max_step=1000.0,\n",
" pass_dy_as_arg=True)\n",
" return result\n",
"res_1 = test_1()\n",
"print(res_1.success, res_1.message)\n",
"print(res_1(5.2))"
]
},
{
"cell_type": "code",
"execution_count": 372,
"id": "5577bce7-f5fe-4245-9766-954eabc431ae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True Integration completed without issue.\n",
"[[ 7.03679027]\n",
" [124.35084124]\n",
" [ -0.24350841]\n",
" [ -0.85926419]]\n"
]
}
],
"source": [
"def test_2():\n",
" args_2 = (0.01, 0.02)\n",
" \n",
" result = \\\n",
" pysolve_ivp(diffeq_args_extra, time_span, initial_conds,\n",
" method=\"RK45\",\n",
" args=args_2, rtol=rtol, atol=atol,\n",
" dense_output=True,\n",
" num_extra=2, first_step=0.0, max_step=1000.0,\n",
" pass_dy_as_arg=True)\n",
" del args_2\n",
" return result\n",
"res_2 = test_2()\n",
"print(res_2.success, res_2.message)\n",
"print(res_2(5.2))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 8d28572

Please sign in to comment.