Skip to content

Commit

Permalink
fixed issues with copying arg data
Browse files Browse the repository at this point in the history
  • Loading branch information
jrenaud90 committed Dec 2, 2024
1 parent 29553ed commit ef3ace0
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
New & Changes:
* `MAX_STEP` can now be cimported from the top-level of CyRK: `from CyRK cimport MAX_STEP`
* Changed uses of `void*` to `char*` for both diffeq additional args and pre-eval outputs. The signature of these functions has changed, please review documentation for correct usage.
* Added new diagnostic info display tool to `cysolve_ivp` and `pysolve_ivp` output that you can access with `<result>.print_diagnostics()`.

Fixes:
* Fixed issue with `cysolve_ivp` (`pysolve_ivp` did not have this bug) where additional args are passed to diffeq _and_ dense output is on _and_ extra output is captured.
Expand All @@ -16,6 +17,7 @@ Fixes:

Tests:
* Fixed tests where additional args were not being used.
* Fixed issue with diffeq test 5.

Documentation:
* Updated the "Advanced CySolver.md" documentation that was out of date.
Expand Down
2 changes: 2 additions & 0 deletions CyRK/cy/cysolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ CySolverResult::CySolverResult(
{
this->retain_solver = true;
}
// TEMP:
this->retain_solver = true;

// Get solution class ready to go
this->reset();
Expand Down
4 changes: 4 additions & 0 deletions CyRK/cy/cysolve.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "cysolve.hpp"
#include <exception>

#include <cstdio>

void baseline_cysolve_ivp_noreturn(
std::shared_ptr<CySolverResult> solution_sptr,
DiffeqFuncType diffeq_ptr,
Expand Down Expand Up @@ -32,6 +34,8 @@ void baseline_cysolve_ivp_noreturn(
const double t_start = t_span_ptr[0];
const double t_end = t_span_ptr[1];
const bool direction_flag = t_start <= t_end ? true : false;
const bool forward = direction_flag == true;
printf("t_start = %e; t_end = %e; direction_flag = %d; forward = %d\n", t_start, t_end, direction_flag, forward);
const bool t_eval_provided = t_eval ? true : false;

// Get new expected size
Expand Down
29 changes: 21 additions & 8 deletions CyRK/cy/cysolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ CySolverBase::CySolverBase(
t_start(t_start),
t_end(t_end),
diffeq_ptr(diffeq_ptr),
size_of_args(size_of_args),
len_t_eval(len_t_eval),
num_extra(num_extra),
use_dense_output(use_dense_output),
Expand All @@ -62,19 +63,31 @@ CySolverBase::CySolverBase(
this->storage_sptr->update_message("CySolverBase Initializing.");

// Build storage for args
if (args_ptr && (size_of_args > 0))
if (args_ptr && (this->size_of_args > 0))
{
// Allocate memory for the size of args.
// Store void pointer to it.
printf("Pre resize\n");
this->args_char_vec.resize(size_of_args);
this->args_ptr = this->args_char_vec.data();
printf("Pre resize; size = %d\n", this->size_of_args);
printf("Pre resize; VECTOR size = %d\n", this->args_char_vec.size());
this->args_char_vec.resize(this->size_of_args);

for (size_t i = 0; i < this->size_of_args; i++)
{
printf("\t %x\n", args_ptr[i] & 0xff);
}


// Copy over contents of arg
char* args_in_as_char_ptr = (char*)args_ptr;
printf("Pre Copy Over: arg_in_char = %p; sizeof = %d; args_in_char+size = %p\n", args_in_as_char_ptr, size_of_args, args_in_as_char_ptr + size_of_args);
this->args_char_vec.insert(this->args_char_vec.begin(), args_in_as_char_ptr, args_in_as_char_ptr + size_of_args);
// std::memcpy(this->args_ptr, args_ptr, size_of_args);
printf("Pre Copy Over: arg_in_char = %p; sizeof = %d; args_in_char+size = %p\n", args_ptr, this->size_of_args, args_ptr + this->size_of_args);
// this->args_char_vec.insert(this->args_char_vec.begin(), args_ptr, args_ptr + this->size_of_args);
this->args_ptr = this->args_char_vec.data();
std::memcpy(this->args_ptr, args_ptr, this->size_of_args);

printf("Vector array size = %d\n", this->args_char_vec.size());
for (size_t i = 0; i < this->args_char_vec.size(); i++)
{
printf("\t %x\n", this->args_char_vec[i] & 0xff);
}

printf("Post\n");
}
Expand Down
6 changes: 3 additions & 3 deletions CyRK/cy/cysolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ class CySolverBase {

// Attributes
protected:
// ** Attributes **

// Time variables
double t_tmp = 0.0;
double t_delta = 0.0;
Expand All @@ -103,10 +101,13 @@ class CySolverBase {
double num_y_dbl = 0.0;
double num_y_sqrt = 0.0;


public:
// Integration step information
size_t max_num_steps = 0;

// Additional arguments for the diffeq are stored locally in a char dynamic vector.
size_t size_of_args = 0;
std::vector<char> args_char_vec = std::vector<char>();
char* args_ptr = nullptr;

Expand Down Expand Up @@ -134,7 +135,6 @@ class CySolverBase {
// Dense (Interpolation) Attributes
bool use_dense_output = false;

public:
// PySolver Attributes
bool use_pysolver = false;
DiffeqMethod py_diffeq_method = nullptr;
Expand Down
3 changes: 3 additions & 0 deletions CyRK/cy/cysolver_api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ cdef extern from "cysolver.cpp" nogil:
size_t num_dy
size_t num_y
shared_ptr[CySolverResult] storage_ptr
size_t size_of_args
vector[char] args_char_vec
char* args_ptr
size_t len_t
double t_now
vector[double] y_now
Expand Down
65 changes: 65 additions & 0 deletions CyRK/cy/cysolver_api.pyx
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# distutils: language = c++
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True, initializedcheck=False
from libc.math cimport floor

import numpy as np
cimport numpy as cnp
cnp.import_array()

from libc.stdio cimport printf

# =====================================================================================================================
# Import CySolverResult (container for integration results)
# =====================================================================================================================
Expand Down Expand Up @@ -51,6 +55,67 @@ cdef class WrapCySolverResult:
self.cyresult_ptr.call_vectorize(t_array_ptr, len_t, y_interp_ptr)
return y_interp_array.reshape(len_t, self.cyresult_ptr.num_dy).T

def print_diagnostics(self):

cdef str diagnostic_str = ''
from CyRK import __version__
cdef str direction_str = 'Forward'
if self.cyresult_ptr.direction_flag == 0:
direction_str = 'Backward'

diagnostic_str += f'CyRK v{__version__} - WrapCySolverResult Diagnostic.\n'
diagnostic_str += f'\n----------------------------------------------------\n'
diagnostic_str += f'# of y: {self.num_y}.\n'
diagnostic_str += f'# of dy: {self.num_dy}.\n'
diagnostic_str += f'Success: {self.success}.\n'
diagnostic_str += f'Error Code: {self.error_code}.\n'
diagnostic_str += f'Size: {self.size}.\n'
diagnostic_str += f'Steps Taken: {self.steps_taken}.\n'
diagnostic_str += f'Message:\n\t{self.message}\n'
diagnostic_str += f'\n----------------- CySolverResult -------------------\n'
diagnostic_str += f'Capture Extra: {self.cyresult_ptr.capture_extra}.\n'
diagnostic_str += f'Capture Dense Output: {self.cyresult_ptr.capture_dense_output}.\n'
diagnostic_str += f'Integration Direction: {direction_str}.\n'
diagnostic_str += f'Integration Method: {self.cyresult_ptr.integrator_method}.\n'
diagnostic_str += f'# of Interpolates: {self.cyresult_ptr.num_interpolates}.\n'

cdef CySolverBase* cysolver = self.cyresult_ptr.solver_uptr.get()
cdef size_t num_y
cdef size_t num_dy
cdef size_t i
cdef size_t args_size
cdef size_t args_size_dbls
cdef double* args_dbl_ptr
if cysolver:
num_y = cysolver.num_y
num_dy = cysolver.num_dy
diagnostic_str += f'\n------------------ CySolverBase --------------------\n'
diagnostic_str += f'Status: {cysolver.status}.\n'
diagnostic_str += f'# of y: {num_y}.\n'
diagnostic_str += f'# of dy: {num_dy}.\n'
diagnostic_str += f'PySolver: {cysolver.use_pysolver}.\n'
diagnostic_str += f't_now: {cysolver.t_now}.\n'
diagnostic_str += f'y_now:\n'
for i in range(num_y):
diagnostic_str += f'\ty{i} = {cysolver.y_now[i]:0.5e}.\n'
diagnostic_str += f'dy_now:\n'
for i in range(num_dy):
diagnostic_str += f'\tdy{i} = {cysolver.dy_now[i]:0.5e}.\n'
args_size = cysolver.size_of_args
args_size_dbls = <size_t>floor(args_size / sizeof(double))
args_dbl_ptr = <double*>cysolver.args_ptr
diagnostic_str += f'args size (bytes): {args_size}.\n'
diagnostic_str += f'args size (doubles): {args_size_dbls}.\n'
if args_size_dbls > 0:
diagnostic_str += f'args (as doubles):\n'
for i in range(args_size_dbls):
diagnostic_str += f'\targ{i} = {args_dbl_ptr[i]:0.5e}.\n'
else:
diagnostic_str += f'CySolverBase instance was deleted or voided.\n'

diagnostic_str += f'\n-------------- Diagnostic Complete -----------------\n'
print(diagnostic_str)

@property
def success(self):
return self.cyresult_ptr.success
Expand Down
30 changes: 18 additions & 12 deletions CyRK/cy/cysolver_test.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ cdef struct ArbitraryArgStruct:
double l
# Let's make sure it has something in the middle that is not a double for size checks.
cpp_bool cause_fail
cpp_bool checker
double m
double g

Expand All @@ -153,6 +154,8 @@ cdef void arbitrary_arg_test(double* dy_ptr, double t, double* y_ptr, char* args
cdef double m = arb_args_ptr.m
cdef double g = arb_args_ptr.g
cdef cpp_bool cause_fail = arb_args_ptr.cause_fail
cdef cpp_bool checker = arb_args_ptr.checker
printf("arbitrary_arg_test:: l = %e, m = %e, g = %e; cause_fail = %d\n", l, m, g, cause_fail)

cdef double coeff_1 = (-3. * g / (2. * l))
cdef double coeff_2 = (3. / (m * l**2))
Expand Down Expand Up @@ -206,10 +209,10 @@ cdef void pendulum_preeval_diffeq(double* dy_ptr, double t, double* y_ptr, char*
cdef double* pre_eval_storage_ptr = &pre_eval_storage[0]

# Cast storage to void so we can call function
cdef char* pre_eval_storage_void_ptr = <char*>pre_eval_storage_ptr
cdef char* pre_eval_storage_char_ptr = <char*>pre_eval_storage_ptr

# Call Pre-Eval Function
pre_eval_func(pre_eval_storage_void_ptr, t, y_ptr, args_ptr)
pre_eval_func(pre_eval_storage_char_ptr, t, y_ptr, args_ptr)

cdef double y0 = y_ptr[0]
cdef double y1 = y_ptr[1]
Expand Down Expand Up @@ -245,7 +248,7 @@ def cy_extra_output_tester():
int_method,
1.0e-4,
1.0e-5,
args_ptr,
<char*>args_ptr,
arg_size,
num_extra
)
Expand All @@ -264,7 +267,10 @@ def cy_extra_output_tester():
e2 = y_interp_ptr[4]
e3 = y_interp_ptr[5]

# Corrupt or otherwise mess up the arg pointer
# Corrupt or otherwise change up the arg pointer
args_ptr[0] = -99.0
args_ptr[1] = -99.0
args_ptr[2] = -99.0
args_ptr = <double*>realloc(args_ptr, sizeof(double)*3000)
cdef size_t i
for i in range(3000):
Expand All @@ -275,12 +281,12 @@ def cy_extra_output_tester():
result.get().call(check_t, y_interp_ptr)
cdef bint passed = True

passed = dy1 == y_interp_ptr[0]
passed = dy2 == y_interp_ptr[1]
passed = dy3 == y_interp_ptr[2]
passed = e1 == y_interp_ptr[3]
passed = e2 == y_interp_ptr[4]
passed = e3 == y_interp_ptr[5]
assert dy1 == y_interp_ptr[0]
assert dy2 == y_interp_ptr[1]
assert dy3 == y_interp_ptr[2]
assert e1 == y_interp_ptr[3]
assert e2 == y_interp_ptr[4]
assert e3 == y_interp_ptr[5]

return passed

Expand Down Expand Up @@ -357,7 +363,7 @@ def cytester(
cdef double[10] args_arr
cdef double* args_ptr_dbl = &args_arr[0]
# Abitrary arg test requires a ArbitraryArgStruct class instance to be passed in
cdef ArbitraryArgStruct arb_arg_struct = ArbitraryArgStruct(1.0, False, 1.0, 9.81)
cdef ArbitraryArgStruct arb_arg_struct = ArbitraryArgStruct(1.0, False, True, 1.0, 9.81)

# Check if generic testing was requested.
printf("Cytester pt3\n")
Expand Down Expand Up @@ -419,7 +425,7 @@ def cytester(
y0_ptr[0] = 10.0
y0_ptr[1] = 5.0
t_span_ptr[0] = 0.0
t_span_ptr[0] = 15.0
t_span_ptr[1] = 15.0
args_ptr_dbl[0] = 1.5
args_ptr_dbl[1] = 1.0
args_ptr_dbl[2] = 3.0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name='CyRK'
version = '0.12.0a0.dev4'
version = '0.12.0a0.dev5'
description='Runge-Kutta ODE Integrator Implemented in Cython and Numba.'
authors= [
{name = 'Joe P. Renaud', email = 'joe.p.renaud@gmail.com'}
Expand Down

0 comments on commit ef3ace0

Please sign in to comment.