-
Notifications
You must be signed in to change notification settings - Fork 1
/
fixed_point.py
413 lines (331 loc) · 11.2 KB
/
fixed_point.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
from threading import Thread
from collections.abc import Callable, Iterable, Mapping
from threading import Thread
from typing import Any
from numba import jit
from sympy import (
Symbol,
diff,
solve,
nroots,
solve,
Eq,
preview,
init_printing,
sympify,
)
from sympy.parsing.sympy_parser import parse_expr
from sympy.polys.polyerrors import PolynomialError
class NoAssumption(RuntimeError):
    """Raised when no suitable starting point (assumption) can be derived for the iteration."""
class ImeginaryNumber(RuntimeError, TypeError, ValueError):
    """
    Error related to imaginary (non-real) numbers.

    Derives from RuntimeError, TypeError and ValueError, so a handler for any
    of those base classes also catches it. The class name keeps its existing
    spelling ("Imeginary") so current references continue to resolve.
    """
def take_diff(expr):
    """
    Differentiate ``expr`` once with respect to the symbol ``x``.

    Args:
        expr (sympy.Basic): Expression to differentiate.
    Returns:
        tuple: ``(derivative, x)`` - the first derivative and the ``x`` symbol used.
    """
    variable = Symbol("x")
    derivative = diff(expr, variable)
    return derivative, variable
def find_extremums(expr) -> list:
    """
    Find the real critical points of ``expr`` (real roots of its first derivative).

    Args:
        expr (sympy.Basic): Expression in ``x``.
    Returns:
        list: Real roots of ``d(expr)/dx`` computed by ``nroots`` with 6 digits,
        sorted ascending. Complex roots are skipped. If ``nroots`` raises
        ``PolynomialError``, a one-element list holding ``float(derivative)`` is
        returned instead.
    Raises:
        TypeError: From ``float()`` in the fallback when the derivative is not a
            numeric constant (SymPy refuses to convert symbolic expressions).
    """
    first_derivative, _ = take_diff(expr)
    try:
        roots = nroots(first_derivative, n=6)
    except PolynomialError:
        return [float(first_derivative)]
    # For a real polynomial, nroots also returns complex-conjugate roots. They
    # are not points on the real axis, and sorting them raises TypeError
    # (SymPy cannot order non-real numbers), so keep only the real ones.
    return sorted(root for root in roots if root.is_real)
def find_lims(expr, extremums: list, tolerance: float = 1e-6) -> list:
    """
    Sample ``expr`` just beside each given point.

    For every point ``p`` the pair ``(f(p - tolerance), f(p - 2*tolerance))``
    is recorded. One extra pair ``(f(q + tolerance), f(q + 2*tolerance))`` is
    appended for the last point ``q`` in the list.

    Args:
        expr (sympy.Basic): Expression in ``x``; only its ``evalf(subs=...)`` is used.
        extremums (list): Points to sample around; must be non-empty.
        tolerance (float): Offset from each point.
    Returns:
        list: ``len(extremums) + 1`` pairs of sampled values.
    Raises:
        IndexError: If ``extremums`` is empty.
    """
    def _sample(point):
        return expr.evalf(subs={"x": point})

    samples: list = []
    for point in extremums:
        samples.append((_sample(point - tolerance), _sample(point - 2 * tolerance)))
    last_point = extremums[-1]
    samples.append((_sample(last_point + tolerance), _sample(last_point + 2 * tolerance)))
    return samples
def find_signs(lims: list) -> list:
    """
    Turn each sample pair from ``find_lims`` into a "+"/"-" marker.

    Args:
        lims (list): Pairs ``(e, e2)`` of sampled values; must be non-empty.
    Returns:
        list: One marker per pair, in the same order.

    NOTE(review): for every pair except the last, both branches of the first
    conditional yield "-", so those entries are always "-". The comparison is
    still evaluated. Only the last pair can produce "+", when ``e2 > e`` and
    ``e2 > 0``. This looks like a typo for ``"+" if ... else "-"``. However,
    changing it alters which points ``make_assumption`` accepts. Tracing the
    ``__main__`` example (x**2 + 2*x + 1) suggests it would then raise
    NoAssumption. Confirm the intended rule before changing it.
    """
    signs = ["-" if e2 > e and e2 > 0 else "-" for e, e2 in lims[:-1]]
    # Unpacks a (at most) one-element list into append(); with an empty ``lims``
    # append() receives no argument and raises TypeError.
    signs.append(*["+" if e2 > e and e2 > 0 else "-" for e, e2 in lims[-1:]])
    return signs
def make_expression(func_repr: str) -> object:
    """
    Parse ``func_repr`` into a SymPy expression using ``parse_expr``.

    Args:
        func_repr (str): Expression text, e.g. ``"x**2 + 2*x + 1"``.
    Returns:
        object: The parsed SymPy expression.

    NOTE(review): SymPy documents that ``parse_expr`` evaluates its input with
    ``eval``. Do not pass strings from untrusted users here without sanitizing
    them; confirm where ``func_repr`` comes from.
    """
    parsed = parse_expr(func_repr)
    return parsed
def find_extremums_first_diff(expr) -> list:
    """
    Find the extremum points of the first derivative (real roots of the second derivative).

    Args:
        expr (sympy.Basic): Expression in ``x``.
    Returns:
        list: Real roots of ``d2(expr)/dx2`` computed by ``nroots`` with 6
        digits, sorted ascending. Complex roots are skipped. There are two
        fallbacks:
        - If no real root is found, ``[-f''(0)]`` is returned.
        - If ``nroots`` raises ``PolynomialError``, ``[float(f'')]`` is returned.
    Raises:
        TypeError: From ``float()`` in the ``PolynomialError`` fallback when the
            second derivative is not a numeric constant.
    """
    second_derivative, _ = take_diff(take_diff(expr)[0])
    try:
        roots = nroots(second_derivative, n=6)
    except PolynomialError:
        return [float(second_derivative)]
    # Complex-conjugate roots are not points on the real axis, and sorting them
    # raises TypeError (SymPy cannot order non-real numbers).
    extremums = sorted(root for root in roots if root.is_real)
    if not extremums:
        # NOTE(review): the negated second derivative at x = 0 is used as a
        # stand-in point; the sign flip differs from the PolynomialError
        # fallback above - confirm which is intended.
        return [-second_derivative.evalf(subs={"x": 0})]
    return extremums
def find_up_down(lims: list):
    """
    Label each sample pair "u" or "d" by comparing its two values.

    Args:
        lims (list): Pairs of values, e.g. the output of ``find_lims``.
    Returns:
        list: One label per pair. The label is "u" when the first value minus
        the second is negative (for ordinary numbers: the second value is
        larger), otherwise "d".

    NOTE(review): whether "u"/"d" corresponds to increasing/decreasing depends
    on which side of the point ``find_lims`` sampled the pair - confirm.
    """
    directions = []
    for first, second in lims:
        gap = first - second
        directions.append("u" if gap < 0 else "d")
    return directions
def check_opposite_sign(signs: list, extremum: float) -> bool:
"""
Checks if there is an opposite sign around a given extremum.
Args:
signs (list): List of signs.
extremum (float): The extremum point.
Returns:
bool: True if there is an opposite sign, False otherwise.
"""
match signs[0]:
case "+":
return "-" in set(signs[1:]) and extremum > 0
case "-":
return "+" in set(signs[1:]) and extremum < 0
return True
def check_opposite_direction(directions: list, extremum: float):
    """
    Check whether a later direction opposes the first one, with ``extremum`` on the matching side of zero.

    Args:
        directions (list): "u"/"d" labels (as produced by ``find_up_down``).
        extremum (float): Candidate point.
    Returns:
        bool: The result depends on the leading label:
        - "u": True when a "d" appears later and ``extremum > 0``.
        - "d": True when a "u" appears later and ``extremum < 0``.
        - Empty list or any other leading label: True.
    """
    if directions == []:
        return True
    leading = directions[0]
    if leading == "u":
        return "d" in set(directions[1:]) and extremum > 0
    if leading == "d":
        return "u" in set(directions[1:]) and extremum < 0
    return True
def make_assumption(extremums: list, directions: list, signs: list) -> list:
    """
    Select the extremums that pass both heuristic checks as starting points.

    For the extremum at index ``i``, ``check_opposite_direction`` is applied
    to ``directions[i:]``. Only if that passes is ``check_opposite_sign``
    applied to ``signs[i:]``.

    Args:
        extremums (list): Candidate points.
        directions (list): "u"/"d" labels from ``find_up_down``.
        signs (list): "+"/"-" markers from ``find_signs``.
    Returns:
        list: The accepted extremums, in their original order.
    """
    return [
        candidate
        for index, candidate in enumerate(extremums)
        if check_opposite_direction(directions[index:], candidate)
        and check_opposite_sign(signs[index:], candidate)
    ]
# NOTE(review): numba can only release the GIL for nopython-mode code. Here
# forceobj=True (object mode, presumably needed because the body handles SymPy
# objects) is used, so ``nogil`` likely has no effect and the threads started
# by ``main`` probably do not run this in parallel - confirm.
@jit(forceobj=True, nogil=True)
def fixed_point_method(
    expr, g_x, x0: float, max_iter: int = 100, tolerance: float = 1e-6
):
    """
    Iterate ``x <- g_x(x)`` starting from ``x0`` and collect the iterates.

    ``g_x`` is evaluated by substituting the current point for the symbol
    ``y``. In ``main`` it is one of the solutions of ``expr == y`` for ``x``.
    ``expr`` (in ``x``) is evaluated at each new iterate to measure the residual.

    Parameters:
    - expr: SymPy expression in ``x`` whose root is sought.
    - g_x: SymPy expression in ``y`` used as the iteration function.
    - x0: Initial guess.
    - max_iter: Maximum number of iterations.
    - tolerance: Stopping tolerance. The loop compares against ``tolerance * 1000``.

    Returns:
    - list of iterates, in order. If an iterate is not known to be real, the
      partial list is discarded. The function then returns the result of
      ``fixed_complex_point_iteration``, continued from the last real iterate
      with the remaining iteration budget.
    """
    result: list = []
    error_1: float = 1.0
    error_2: float = 1.0
    imag: bool = False
    # Continue only while BOTH the step size and the residual exceed
    # tolerance * 1000; either one dropping below stops the loop.
    while max_iter > 0 and error_1 > tolerance * 1000 and error_2 > tolerance * 1000:
        # g_x is written in terms of ``y``; substitute the current point for it.
        x_new: float = g_x.evalf(subs={"y": x0})
        # ``is_real`` may be None (unknown); ``not None`` is True, so an
        # undetermined value also hands over to the complex routine.
        if not sympify(x_new).is_real:
            imag = True
            break
        error_1 = abs(x_new - x0)  # step size
        error_2 = abs(expr.evalf(subs={"x": x_new}))  # residual |expr(x_new)|
        result.append(x_new)
        x0 = x_new
        max_iter -= 1
    if imag:
        # x0 / max_iter still hold the last real iterate and remaining budget.
        return fixed_complex_point_iteration(expr, x0, max_iter, tolerance)
    return result
def fixed_complex_point_iteration(
    expr, x0: float, max_iter: int = 100, tolerance: float = 1e-6
) -> list:
    """
    Iterate ``x <- expr(x)`` from ``x0`` and return the visited points.

    Each step records the current point, evaluates ``expr`` there and moves
    to that value. The loop stops after ``max_iter`` steps, or as soon as
    ``|expr(point)|`` is no longer above ``tolerance * 1000``. In the latter
    case the last recorded point is the first one whose residual fell under
    that threshold.

    Parameters:
    - expr: Expression in ``x``; only its ``evalf(subs=...)`` is used.
    - x0: Starting point (recorded first).
    - max_iter: Maximum number of steps.
    - tolerance: Stopping tolerance (compared as ``tolerance * 1000``).

    Returns:
        list: The visited points, starting with ``x0``.
    """
    visited: list = []
    current = x0
    remaining = max_iter
    residual = 1.0
    while remaining > 0 and residual > tolerance * 1000:
        image = expr.evalf(subs={"x": current})
        residual = abs(image)
        visited.append(current)
        current = image
        remaining -= 1
    return visited
class ReturnThread(Thread):
def __init__(
self,
group: None = None,
target: Callable[..., object] | None = None,
name: str | None = None,
args: Iterable[Any] = ...,
kwargs: Mapping[str, Any] | None = None,
*,
daemon: bool | None = None,
) -> None:
"""
A custom thread class that captures the result of the target function.
Parameters:
- target: Target function to be executed in the thread.
- args: Arguments to be passed to the target function.
"""
super().__init__(target=target, args=args)
self.result = None
self.target: callable = target
self.args: Iterable[Any] = args
def run(self) -> None:
"""
Runs the target function in the thread and captures the result.
"""
self.result = self.target(*self.args)
def make_g_x(expr):
    """
    Solve ``expr == y`` for ``x`` to obtain candidate iteration functions.

    Args:
        expr (sympy.Basic): Expression in ``x``.
    Returns:
        list: Every solution returned by ``solve``, typically expressions in
        ``y``. ``fixed_point_method`` substitutes the current iterate for ``y``.
    """
    x_symbol, y_symbol = Symbol("x"), Symbol("y")
    return list(solve(Eq(expr, y_symbol), x_symbol))
def make_func_imgs(expr, g_x: list) -> None:
    """
    Render the expressions to PNG files with SymPy's ``preview``.

    Writes ``g_x_<i>.png`` for each iteration function (``i`` is its index)
    and ``f_x.png`` for ``expr``. The paths are relative, so the files land in
    the current working directory.

    Args:
        expr (sympy.Basic): The original expression.
        g_x (list): Iteration functions to render.
    """
    for index, candidate in enumerate(g_x):
        with open(f"g_x_{index}.png", "wb") as image_file:
            preview(candidate, viewer="BytesIO", outputbuffer=image_file)
    with open("f_x.png", "wb") as image_file:
        preview(expr, viewer="BytesIO", outputbuffer=image_file)
def main(func_repr: str, max_iter: int = 500, tolerance: float = 1e-6):
    """
    Find roots of ``func_repr`` by fixed-point iteration from heuristic starting points.

    Steps:
    1. Parse the expression.
    2. Solve ``expr == y`` for ``x`` to get the iteration functions.
    3. Pick starting points from derivative-based heuristics.
    4. Run ``fixed_point_method`` for every (start, g) combination in its own thread.
    5. Write the preview images and return the runs that produced a result.

    Args:
        func_repr (str): String representation of the expression in ``x``.
        max_iter (int): Maximum number of iterations per run.
        tolerance (float): Tolerance passed to every run. Starting points are
            also offset by ``-2 * tolerance`` and ``+2 * tolerance``.
    Returns:
        list: One list of iterates per run that did not raise. All
        ``-2 * tolerance`` starts come first, then all ``+2 * tolerance``
        starts, each crossed with every iteration function.
    Raises:
        NoAssumption: If no starting point passes the heuristics.
    """
    expr = make_expression(func_repr=func_repr)
    # SymPy expressions are immutable, so the parsed expression is reused
    # instead of parsing ``func_repr`` a second time.
    g_x: list = make_g_x(expr)

    extremum_first_diff: list = find_extremums_first_diff(expr)
    extremums: list = find_extremums(expr)
    lims: list = find_lims(expr, extremums, tolerance)
    signs: list = find_signs(lims)
    lims_second_diff: list = find_lims(expr, extremum_first_diff, tolerance)
    directions: list = find_up_down(lims_second_diff)
    assumptions: list = make_assumption(extremums, directions, signs)
    if not assumptions:
        raise NoAssumption("Starting point did not be found, Use different function")
    print(assumptions)

    # Slightly left of every assumption first, then slightly right.
    starts: list = [assumption - 2 * tolerance for assumption in assumptions]
    starts += [assumption + 2 * tolerance for assumption in assumptions]
    # One thread per (start, g) pair.
    # NOTE(review): fixed_point_method works on SymPy objects in numba object
    # mode, so these threads probably hold the GIL and run one at a time - confirm.
    threads: list = [
        ReturnThread(
            target=fixed_point_method,
            args=(expr, g, start, max_iter, tolerance),
        )
        for start in starts
        for g in g_x
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    results: list = [thread.result for thread in threads]

    make_func_imgs(expr, g_x)
    # ReturnThread.result stays None when the target raised. Drop those runs
    # in one pass rather than repeatedly scanning the list with remove().
    return [result for result in results if result is not None]
if __name__ == "__main__":
    # Enable SymPy pretty printing (use_latex="mathjax").
    init_printing(use_latex="mathjax")
    # Demo: x**2 + 2*x + 1 == (x + 1)**2, which has a double root at x = -1.
    demo_roots = main("x**2 + 2*x + 1")
    print(demo_roots)