Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions pykis/kis.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def virtual(self) -> bool:
"""웹소켓 클라이언트"""
_keep_token: Path | None
"""API 접속 토큰 자동 저장 경로"""
_sessions: dict[Literal["real", "virtual"], requests.Session]
"""API 세션"""

@property
def keep_token(self) -> bool:
Expand Down Expand Up @@ -420,6 +422,13 @@ def __init__(
if isinstance(virtual_token, KisAccessToken)
else KisAccessToken.load(virtual_token) if self.virtual and virtual_token else None
)
self._sessions = {
"real": requests.Session(),
"virtual": requests.Session(),
}

for session in self._sessions.values():
session.headers.update({"User-Agent": USER_AGENT})

if keep_token:
if keep_token is True:
Expand Down Expand Up @@ -516,44 +525,43 @@ def request(
elif body is None:
body = {}

if headers is None:
headers = {}
request_headers = headers.copy() if headers else {}

if domain is None:
domain = "virtual" if self.virtual else "real"

session = self._sessions[domain]

if appkey_location:
appkey = self.appkey if domain == "real" else self.virtual_appkey

if appkey is None:
raise ValueError("모의도메인 AppKey가 없습니다.")

appkey.build(headers if appkey_location == "header" else body)
appkey.build(request_headers if appkey_location == "header" else body)

if form is not None:
if form_location is None:
form_location = "params" if method == "GET" else "body"

dist = headers if form_location == "header" else params if form_location == "params" else body
dist = request_headers if form_location == "header" else params if form_location == "params" else body

for f in form:
if f is not None:
f.build(dist)

headers["User-Agent"] = USER_AGENT

rate_limit = self._rate_limiters[domain]

while True:
rate_limit.acquire(blocking_callback=self._rate_limit_exceeded)

if auth:
(self.token if domain == "real" else self.primary_token).build(headers)
(self.token if domain == "real" else self.primary_token).build(request_headers)

resp = requests.request(
resp = session.request(
method=method,
url=urljoin(REAL_DOMAIN if domain == "real" else VIRTUAL_DOMAIN, path),
headers=headers,
headers=request_headers,
params=params,
json=body,
)
Expand All @@ -563,7 +571,7 @@ def request(

try:
data = resp.json()
except:
except Exception:
data = None

error_code = data.get("msg_cd") if data is not None else None
Expand Down Expand Up @@ -731,6 +739,15 @@ def websocket(self) -> KisWebsocketClient:

return self._websocket

def close(self) -> None:
"""API 세션을 종료합니다."""
for session in self._sessions.values():
session.close()

def __del__(self) -> None:
"""API 세션을 종료합니다."""
self.close()

from pykis.api.stock.trading_hours import trading_hours
from pykis.scope.account import account
from pykis.scope.stock import stock