Skip to content

Commit

Permalink
[python] add type hints to logging functions in basic.py (#4527)
Browse files Browse the repository at this point in the history
* [python] add type hints to logging functions in basic.py

* add hints on wrapper
  • Loading branch information
jameslamb authored Aug 19, 2021
1 parent 73f7d5d commit c65a2e3
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from os.path import getsize
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import numpy as np
import scipy.sparse
Expand All @@ -34,17 +34,17 @@ def _get_sample_count(total_nrow: int, params: str):


class _DummyLogger:
def info(self, msg):
def info(self, msg: str) -> None:
print(msg)

def warning(self, msg):
def warning(self, msg: str) -> None:
warnings.warn(msg, stacklevel=3)


_LOGGER = _DummyLogger()
_LOGGER: Union[_DummyLogger, Logger] = _DummyLogger()


def register_logger(logger):
def register_logger(logger: Logger) -> None:
"""Register custom logger.
Parameters
Expand All @@ -58,12 +58,12 @@ def register_logger(logger):
_LOGGER = logger


def _normalize_native_string(func):
def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], None]:
"""Join log messages from native library which come by chunks."""
msg_normalized = []
msg_normalized: List[str] = []

@wraps(func)
def wrapper(msg):
def wrapper(msg: str) -> None:
nonlocal msg_normalized
if msg.strip() == '':
msg = ''.join(msg_normalized)
Expand All @@ -75,20 +75,20 @@ def wrapper(msg):
return wrapper


def _log_info(msg):
def _log_info(msg: str) -> None:
_LOGGER.info(msg)


def _log_warning(msg):
def _log_warning(msg: str) -> None:
_LOGGER.warning(msg)


@_normalize_native_string
def _log_native(msg):
def _log_native(msg: str) -> None:
_LOGGER.info(msg)


def _log_callback(msg):
def _log_callback(msg: bytes) -> None:
"""Redirect logs from native library into Python."""
_log_native(str(msg.decode('utf-8')))

Expand Down

0 comments on commit c65a2e3

Please sign in to comment.