Skip to content

Commit

Permalink
fix: typing
Browse files Browse the repository at this point in the history
  • Loading branch information
severinsimmler committed Dec 19, 2022
1 parent 3c8a018 commit 6682b19
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 46 deletions.
72 changes: 31 additions & 41 deletions chaine/_core/crf.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions chaine/_core/crf.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ from libcpp.string cimport string
import os

from chaine.logging import Logger
from chaine.typing import Filepath, Labels, Sequence, Union
from chaine.typing import Filepath, Labels, Sequence

LOGGER = Logger(__name__)

Expand Down Expand Up @@ -110,7 +110,7 @@ cdef class Trainer:

self._trainer.append(to_seq(sequence), labels, group)

def translate_params(self, kwargs: dict[str, Union[str, int, float, bool]]):
def translate_params(self, kwargs: dict[str, str | int | float | bool]):
return {
self.kwarg2param.get(kwarg, kwarg): value
for kwarg, value in kwargs.items()
Expand All @@ -124,19 +124,19 @@ cdef class Trainer:
if not self._trainer.select(algorithm, "crf1d"):
raise ValueError(f"{algorithm} is no available algorithm")

def set_params(self, params: dict[str, Union[str, int, float, bool]]):
def set_params(self, params: dict[str, str | int | float | bool]):
for param, value in params.items():
self.set_param(param, value)

def set_param(self, param: str, value: Union[str, int, float, bool]):
def set_param(self, param: str, value: str | int | float | bool):
if isinstance(value, bool):
value = int(value)
self._trainer.set(param, str(value))

def get_param(self, param: str):
return self.cast_parameter(param, self._trainer.get(param))

def cast_parameter(self, param: str, value: Union[str, int, float, bool]):
def cast_parameter(self, param: str, value: str | int | float | bool):
if param in self._parameter_types:
return self._parameter_types[param](value)
return value
Expand Down

0 comments on commit 6682b19

Please sign in to comment.