From dc88e4a555e57bcb6ac85da31aabb95e4c6b23a7 Mon Sep 17 00:00:00 2001 From: Thync Date: Tue, 24 Sep 2024 15:39:51 +0700 Subject: [PATCH] Update: Hub Message Model --- pyproject.toml | 2 +- vdatafeed/ssi/hub.py | 32 ++------------------- vdatafeed/ssi/model.py | 48 +++++++++++++++++++++++++++++++- vdatafeed/utils/__init__.py | 2 +- vdatafeed/utils/model_handler.py | 2 +- 5 files changed, 53 insertions(+), 33 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7da09d3..7282002 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "vdatafeed" -version = "1.1.2" +version = "1.1.3" license = "MIT" description = "vDatafeed: A Python wrapper for Viet Nam Datafeed API" repository = "https://github.com/quant-vn/vdatafeed" diff --git a/vdatafeed/ssi/hub.py b/vdatafeed/ssi/hub.py index abeb03e..c9a78e2 100644 --- a/vdatafeed/ssi/hub.py +++ b/vdatafeed/ssi/hub.py @@ -3,6 +3,7 @@ from urllib.parse import urlencode from .constant import HUB_URL, HUB +from .model import TradeTick, QuoteTick from ..interface_datafeed_hub import IDatafeedHUB from ..utils import SocketListener, request_handler @@ -85,41 +86,14 @@ async def listen(self, args, on_trade_message, on_quote_message): if "A" not in i or not i["A"]: continue msg = json.loads(json.loads(i["A"][0]).get("Content")) - # print(msg) if msg.get("symbol") not in last_vol: last_vol[msg.get("symbol")] = msg.get("LastVol") else: if last_vol[msg.get("symbol")] == msg.get("LastVol"): - _l = range(1, 11) - _rev_l = list(reversed(_l)) - msg = { - "datetime": " ".join([ - msg.get("TradingDate"), msg.get("Time") - ]), - "symbol": msg.get("Symbol"), - "ce": msg.get("Ceiling"), - "fl": msg.get("Floor"), - "re": msg.get("RefPrice"), - "bid_price": [msg.get(f"BidPrice{i}") for i in _rev_l], - "bid_vol": [msg.get(f"BidVol{i}") for i in _rev_l], - "ask_price": [msg.get(f"AskPrice{i}") for i in _l], - "ask_vol": [msg.get(f"AskVol{i}") for i in _l] - } - on_quote_message(msg) + on_quote_message(QuoteTick(**msg)) continue last_vol[msg.get("symbol")] = msg.get("LastVol") - msg = { - "datetime": " ".join([msg.get("TradingDate"), msg.get("Time")]), - "symbol": msg.get("Symbol"), - "ce": msg.get("Ceiling"), - "fl": msg.get("Floor"), - "re": msg.get("RefPrice"), - "price": msg.get("LastPrice"), - "vol": msg.get("LastVol"), - "t_vol": msg.get("TotalVol"), - "t_val": msg.get("TotalVal"), - } - on_trade_message(msg) + on_trade_message(TradeTick(**msg)) except Exception as e: print(f" Connection error: {e}") except Exception as e: diff --git a/vdatafeed/ssi/model.py b/vdatafeed/ssi/model.py index 2fe2c3d..1b193c7 100644 --- a/vdatafeed/ssi/model.py +++ b/vdatafeed/ssi/model.py @@ -1,6 +1,6 @@ """ Model for SSI datafeed """ from typing import Optional -from ..utils import BaseModel, Field, AliasChoices +from ..utils import BaseModel, Field, AliasChoices, model_validator class SecuritiesInfo(BaseModel): @@ -290,3 +290,49 @@ class IntradayOHLC(BaseModel): close: Optional[str] = Field(validation_alias=AliasChoices('close', 'Close')) vol: Optional[str] = Field(validation_alias=AliasChoices('vol', 'Volume')) val: Optional[str] = Field(validation_alias=AliasChoices('val', 'Value')) + + +class TradeTick(BaseModel): + datetime: Optional[str] = None + symbol: Optional[str] = Field(validation_alias=AliasChoices('symbol', 'Symbol')) + ceiling: Optional[float] = Field(validation_alias=AliasChoices('ceiling', 'Ceiling')) + floor: Optional[float] = Field(validation_alias=AliasChoices('floor', 'Floor')) + ref_price: Optional[float] = Field(validation_alias=AliasChoices('ref_price', 'RefPrice')) + price: Optional[float] = Field(validation_alias=AliasChoices('price', 'LastPrice')) + vol: Optional[float] = Field(validation_alias=AliasChoices('vol', 'LastVol')) + total_vol: Optional[float] = Field(validation_alias=AliasChoices('total_vol', 'TotalVol')) + total_val: Optional[float] = Field(validation_alias=AliasChoices('total_val', 'TotalVal')) + + @model_validator(mode='before') + def set_custom_field(cls, values): + values['datetime'] = ' '.join([ + "/".join(reversed(values.get('TradingDate').split("/"))), + values.get('Time') + ]) + return values + + +class QuoteTick(BaseModel): + datetime: Optional[str] = None + symbol: Optional[str] = Field(validation_alias=AliasChoices('symbol', 'Symbol')) + ceiling: Optional[float] = Field(validation_alias=AliasChoices('ceiling', 'Ceiling')) + floor: Optional[float] = Field(validation_alias=AliasChoices('floor', 'Floor')) + ref_price: Optional[float] = Field(validation_alias=AliasChoices('ref_price', 'RefPrice')) + bid_price: Optional[list] = [] + bid_vol: Optional[list] = [] + ask_price: Optional[list] = [] + ask_vol: Optional[list] = [] + + @model_validator(mode='before') + def set_custom_field(cls, values): + values['datetime'] = ' '.join([ + "/".join(reversed(values.get('TradingDate').split("/"))), + values.get('Time') + ]) + _l = range(1, 11) + _rev_l = list(reversed(_l)) + values['bid_price'] = [values.get(f"BidPrice{i}") for i in _rev_l] + values['bid_vol'] = [values.get(f"BidVol{i}") for i in _rev_l], + values['ask_price'] = [values.get(f"AskPrice{i}") for i in _l], + values['ask_vol'] = [values.get(f"AskVol{i}") for i in _l] + return values diff --git a/vdatafeed/utils/__init__.py b/vdatafeed/utils/__init__.py index a8138a4..28d9aa7 100644 --- a/vdatafeed/utils/__init__.py +++ b/vdatafeed/utils/__init__.py @@ -1,5 +1,5 @@ from .enum_handler import EnumHandler # noqa: F401 -from .model_handler import BaseModel, AliasChoices, Field # noqa: F401 +from .model_handler import BaseModel, AliasChoices, Field, model_validator # noqa: F401 from .request_handler import request_handler # noqa: F401 from .socket_handler import SocketListener # noqa: F401 from .jwt_handler import jwt_handler # noqa: F401 diff --git a/vdatafeed/utils/model_handler.py b/vdatafeed/utils/model_handler.py index ecd49bb..4342b4a 100644 --- a/vdatafeed/utils/model_handler.py +++ b/vdatafeed/utils/model_handler.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel as BM, ConfigDict, Field, AliasChoices # noqa: F401 +from pydantic import BaseModel as BM, ConfigDict, Field, AliasChoices, model_validator # noqa: F401 class BaseModel(BM):