367 lines
12 KiB
Python
367 lines
12 KiB
Python
![]() |
import asyncio
|
||
|
import base64
|
||
|
import http
|
||
|
import random
|
||
|
import time
|
||
|
from urllib.parse import urlparse, parse_qs
|
||
|
|
||
|
import requests
|
||
|
import websockets
|
||
|
|
||
|
from lark_oapi.core.cache import ExpiringCache
|
||
|
from lark_oapi.core.const import UTF_8, FEISHU_DOMAIN
|
||
|
from lark_oapi.core.enum import LogLevel
|
||
|
from lark_oapi.core.json import JSON
|
||
|
from lark_oapi.core.log import logger
|
||
|
from lark_oapi.core.utils import Strings
|
||
|
from lark_oapi.event.dispatcher_handler import EventDispatcherHandler
|
||
|
from lark_oapi.ws.const import *
|
||
|
from lark_oapi.ws.enum import FrameType, MessageType
|
||
|
from lark_oapi.ws.exception import *
|
||
|
from lark_oapi.ws.model import *
|
||
|
from lark_oapi.ws.pb.google.protobuf.internal.containers import RepeatedCompositeFieldContainer
|
||
|
from lark_oapi.ws.pb.pbbp2_pb2 import Frame
|
||
|
|
||
|
def _obtain_event_loop():
    """Return the event loop shared by every Client in this module.

    Reuses the current thread's loop when ``asyncio.get_event_loop()`` can
    provide one; if it raises RuntimeError, a fresh loop is created and
    registered as the current loop for this thread.
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)
        return fresh_loop


# Module-level loop used by Client.start() and the background tasks it spawns.
loop = _obtain_event_loop()
|
||
|
|
||
|
|
||
|
def _get_by_key(headers: RepeatedCompositeFieldContainer, key: str) -> str:
    """Return the value of the first header in ``headers`` whose key equals ``key``.

    Raises:
        HeaderNotFoundException: no header carries ``key``.
    """
    not_found = object()
    value = next((h.value for h in headers if h.key == key), not_found)
    if value is not_found:
        raise HeaderNotFoundException(key)
    return value
|
||
|
|
||
|
|
||
|
def _new_ping_frame(service_id: int) -> Frame:
    """Build a CONTROL frame for ``service_id`` carrying a single PING type header.

    SeqID and LogID are zeroed; the frame has no payload.
    """
    ping = Frame()
    ping.service = service_id
    ping.method = FrameType.CONTROL.value
    ping.SeqID = 0
    ping.LogID = 0

    type_header = ping.headers.add()
    type_header.key = HEADER_TYPE
    type_header.value = MessageType.PING.value
    return ping
|
||
|
|
||
|
|
||
|
def _ordinal(n: int):
|
||
|
suffixes = {1: 'st', 2: 'nd', 3: 'rd'}
|
||
|
if 10 <= n <= 20:
|
||
|
suffix = 'th'
|
||
|
else:
|
||
|
suffix = suffixes.get(n % 10, 'th')
|
||
|
return str(n) + suffix
|
||
|
|
||
|
|
||
|
async def _select():
|
||
|
while True:
|
||
|
await asyncio.sleep(3600)
|
||
|
|
||
|
|
||
|
def _parse_ws_conn_exception(e: websockets.InvalidStatusCode):
    """Translate a rejected WebSocket handshake into a ClientException or ServerException.

    The handshake outcome is read from the response headers
    HEADER_HANDSHAKE_STATUS and HEADER_HANDSHAKE_MSG. This function never
    returns normally.

    Raises:
        ClientException: status is AUTH_FAILED with auth error code
            EXCEED_CONN_LIMIT, or status is FORBIDDEN. Client.start() and
            Client._try_connect() re-raise these instead of reconnecting.
        ServerException: any other status, including AUTH_FAILED with a
            missing or different auth error code.
        websockets.InvalidStatusCode: ``e`` itself, when the status or message
            header is missing or the status is not an integer.
    """
    def _as_int(value):
        # Header values are strings; missing (None) or non-numeric values map to None.
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    code = _as_int(e.headers.get(HEADER_HANDSHAKE_STATUS))
    msg = e.headers.get(HEADER_HANDSHAKE_MSG)
    if code is None or msg is None:
        # Nothing to translate without the server's status headers.
        raise e

    if code == AUTH_FAILED:
        # The auth error code header may be absent; treat that as a server-side failure.
        auth_code = _as_int(e.headers.get(HEADER_HANDSHAKE_AUTH_ERRCODE))
        if auth_code == EXCEED_CONN_LIMIT:
            raise ClientException(code, msg) from e
        raise ServerException(code, msg) from e
    if code == FORBIDDEN:
        raise ClientException(code, msg) from e
    raise ServerException(code, msg) from e
|
||
|
|
||
|
|
||
|
class Client(object):
    """WebSocket long-connection client that feeds incoming events to an EventDispatcherHandler.

    ``start()`` requests a connection URL from ``domain + GEN_ENDPOINT_URI``
    using the app credentials and opens a WebSocket to it. It then runs the
    module-level event loop forever. One background task pings the server
    every ``_ping_interval`` seconds, and another reads frames and dispatches
    data frames to ``event_handler``. When ``auto_reconnect`` is true, a failed
    or dropped connection is retried using the server-provided ClientConfig
    (see ``_configure``).
    """

    def __init__(self,
                 app_id: str,
                 app_secret: str,
                 log_level: LogLevel = LogLevel.INFO,
                 event_handler: EventDispatcherHandler = None,
                 domain: str = FEISHU_DOMAIN,
                 auto_reconnect: bool = True) -> None:
        """Store configuration; no network activity happens until ``start()``.

        Args:
            app_id: application id sent when requesting the connection URL.
            app_secret: application secret sent with ``app_id``.
            log_level: level applied to the shared ``logger``.
            event_handler: receives EVENT payloads via ``do_without_validation``.
            domain: base URL that GEN_ENDPOINT_URI is appended to.
            auto_reconnect: retry after connection failures instead of raising.
        """
        self._app_id: str = app_id
        self._app_secret: str = app_secret
        self._log_level: LogLevel = log_level
        self._event_handler: EventDispatcherHandler = event_handler
        self._auto_reconnect: bool = auto_reconnect
        self._domain: str = domain
        # Connection state; _disconnect() resets all four together.
        self._conn: Optional[websockets.WebSocketClientProtocol] = None
        self._conn_url: str = ""
        self._service_id: str = ""
        self._conn_id: str = ""
        # Reconnect/ping tuning; _configure() overwrites these with server-sent values.
        self._reconnect_nonce: int = 30      # upper bound (s) of random delay before reconnecting
        self._reconnect_count: int = -1      # number of reconnect attempts; negative = forever
        self._reconnect_interval: int = 120  # delay (s) between reconnect attempts
        self._ping_interval: int = 120       # delay (s) between pings
        # Partial fragments of multi-part messages, keyed by message id (see _combine).
        self._cache: ExpiringCache = ExpiringCache(clear_interval=30)
        # Serializes connect, disconnect and send on self._conn.
        self._lock = asyncio.Lock()
        logger.setLevel(log_level.value)

    def start(self) -> None:
        """Connect, start pinging, and block the calling thread forever.

        Raises:
            ClientException: the initial connect failed for a reason retrying cannot fix.
            Exception: the initial connect failed and ``auto_reconnect`` is disabled.
        """
        try:
            loop.run_until_complete(self._connect())
        except ClientException as e:
            logger.error(self._fmt_log("connect failed, err: {}", e))
            raise e
        except Exception as e:
            logger.error(self._fmt_log("connect failed, err: {}", e))
            loop.run_until_complete(self._disconnect())
            if self._auto_reconnect:
                loop.run_until_complete(self._reconnect())
            else:
                raise e

        loop.create_task(self._ping_loop())
        # _select() never returns, so the loop keeps running the background tasks.
        loop.run_until_complete(_select())

    async def _ping_loop(self):
        """Send a PING frame every ``_ping_interval`` seconds while connected; failures are only logged."""
        while True:
            try:
                if self._conn is not None:
                    frame = _new_ping_frame(int(self._service_id))
                    await self._write_message(frame.SerializeToString())
                    logger.debug(self._fmt_log("ping success"))
            except Exception as e:
                logger.warning(self._fmt_log("ping failed, err: {}", e))
            finally:
                await asyncio.sleep(self._ping_interval)

    async def _connect(self) -> None:
        """Open the WebSocket connection unless one is already open.

        On success the connection details are stored and a background task
        starts reading messages.

        Raises:
            ClientException / ServerException: the handshake was rejected
                (see ``_parse_ws_conn_exception``), or the endpoint request failed.
        """
        # ``async with`` guarantees the lock is released on every path,
        # including the early return below.
        async with self._lock:
            if self._conn is not None:
                return

            try:
                conn_url = self._get_conn_url()
                u = urlparse(conn_url)
                q = parse_qs(u.query)
                conn_id = q[DEVICE_ID][0]
                service_id = q[SERVICE_ID][0]

                conn = await websockets.connect(conn_url)
                self._conn = conn
                self._conn_url = conn_url
                self._conn_id = conn_id
                self._service_id = service_id

                logger.info(self._fmt_log("connected to {}", conn_url))
                loop.create_task(self._receive_message_loop())
            except websockets.InvalidStatusCode as e:
                _parse_ws_conn_exception(e)

    async def _receive_message_loop(self):
        """Read frames until the connection fails, handling each in its own task.

        On failure the connection is torn down and, when ``auto_reconnect``
        is set, re-established; otherwise the error is re-raised.
        """
        try:
            while True:
                if self._conn is None:
                    raise ConnectionClosedException("connection is closed")
                msg = await self._conn.recv()
                loop.create_task(self._handle_message(msg))
        except Exception as e:
            logger.error(self._fmt_log("receive message loop exit, err: {}", e))
            await self._disconnect()
            if self._auto_reconnect:
                await self._reconnect()
            else:
                raise e

    def _get_conn_url(self) -> str:
        """Request a WebSocket URL for this app and apply any ClientConfig in the response.

        Raises:
            ClientException: credentials are missing, or the endpoint returned a
                non-retryable error code.
            ServerException: HTTP status is not 200, or the endpoint reported
                SYSTEM_BUSY / INTERNAL_ERROR.
        """
        if Strings.is_empty(self._app_id) or Strings.is_empty(self._app_secret):
            raise ClientException(NO_CREDENTIAL, "app_id or app_secret is null")

        # NOTE(review): requests.post is synchronous and has no timeout, so it
        # blocks the event loop while it waits for the endpoint.
        response = requests.post(
            self._domain + GEN_ENDPOINT_URI,
            headers={
                "locale": "zh",
            },
            json={
                "AppID": self._app_id,
                "AppSecret": self._app_secret,
            },
        )
        if response.status_code != http.HTTPStatus.OK:
            raise ServerException(response.status_code, "system busy")

        resp = JSON.unmarshal(str(response.content, UTF_8), EndpointResp)
        if resp.code == OK:
            pass
        elif resp.code == SYSTEM_BUSY:
            raise ServerException(resp.code, "system busy")
        elif resp.code == INTERNAL_ERROR:
            raise ServerException(resp.code, resp.msg)
        else:
            raise ClientException(resp.code, resp.msg)

        data = resp.data
        if data.ClientConfig is not None:
            self._configure(data.ClientConfig)

        return data.URL

    async def _handle_message(self, msg: bytes) -> None:
        """Decode one raw frame and route it by FrameType; any error is logged, not raised."""
        try:
            frame = Frame()
            frame.ParseFromString(msg)
            ft = FrameType(frame.method)

            if ft == FrameType.CONTROL:
                await self._handle_control_frame(frame)
            elif ft == FrameType.DATA:
                await self._handle_data_frame(frame)
        except Exception as e:
            logger.error(self._fmt_log("handle message failed, err: {}", e))

    async def _handle_control_frame(self, frame: Frame):
        """Ignore PING; on PONG, apply the ClientConfig carried in the payload, if any."""
        hs = frame.headers
        type_ = _get_by_key(hs, HEADER_TYPE)
        message_type = MessageType(type_)

        if message_type == MessageType.PING:
            return
        elif message_type == MessageType.PONG:
            logger.debug(self._fmt_log("receive pong"))
            if not frame.payload:
                return
            conf = JSON.unmarshal(str(frame.payload, UTF_8), ClientConfig)
            self._configure(conf)

    async def _handle_data_frame(self, frame: Frame):
        """Process a DATA frame and reply on the same frame.

        Fragmented messages (HEADER_SUM > 1) are buffered until complete.
        EVENT payloads go to the event handler. The frame is then sent back
        with a JSON ``Response`` payload and a HEADER_BIZ_RT header holding
        the handler time in milliseconds. CARD and unknown message types get
        no reply.
        """
        hs = frame.headers
        msg_id = _get_by_key(hs, HEADER_MESSAGE_ID)
        trace_id = _get_by_key(hs, HEADER_TRACE_ID)
        sum_ = _get_by_key(hs, HEADER_SUM)
        seq = _get_by_key(hs, HEADER_SEQ)
        type_ = _get_by_key(hs, HEADER_TYPE)

        pl = frame.payload
        if int(sum_) > 1:
            # Reassemble a fragmented message; wait until every fragment has arrived.
            pl = self._combine(msg_id, int(sum_), int(seq), pl)
            if pl is None:
                return

        message_type = MessageType(type_)
        logger.debug(self._fmt_log("receive message, message_type: {}, message_id: {}, trace_id: {}, payload: {}",
                                   message_type.value, msg_id, trace_id, pl.decode(UTF_8)))

        resp = Response(code=http.HTTPStatus.OK)
        try:
            start = int(round(time.time() * 1000))
            if message_type == MessageType.EVENT:
                result = self._event_handler.do_without_validation(pl)
            elif message_type == MessageType.CARD:
                return
            else:
                return
            end = int(round(time.time() * 1000))
            header = hs.add()
            header.key = HEADER_BIZ_RT
            header.value = str(end - start)
            if result is not None:
                resp.data = base64.b64encode(JSON.marshal(result).encode(UTF_8))
        except Exception as e:
            logger.error(
                self._fmt_log("handle message failed, message_type: {}, message_id: {}, trace_id: {}, err: {}",
                              message_type.value, msg_id, trace_id, e))
            resp = Response(code=http.HTTPStatus.INTERNAL_SERVER_ERROR)

        frame.payload = JSON.marshal(resp).encode(UTF_8)
        await self._write_message(frame.SerializeToString())

    async def _reconnect(self):
        """Retry ``_connect`` according to the reconnect settings.

        Raises:
            ServerUnreachableException: all ``_reconnect_count`` attempts failed
                (only when the count is non-negative).
            ClientException: an attempt failed for a non-retryable reason.
        """
        # Random jitter before the first reconnect attempt.
        if self._reconnect_nonce > 0:
            nonce = random.random() * self._reconnect_nonce
            await asyncio.sleep(nonce)

        # Reconnect: a bounded number of attempts, or forever when the count is negative.
        if self._reconnect_count >= 0:
            for i in range(self._reconnect_count):
                if await self._try_connect(i):
                    return
                await asyncio.sleep(self._reconnect_interval)
            raise ServerUnreachableException(
                f"unable to connect to the server after trying {self._reconnect_count} times")
        else:
            i = 0
            while True:
                if await self._try_connect(i):
                    return
                await asyncio.sleep(self._reconnect_interval)
                i += 1

    async def _try_connect(self, cnt: int) -> bool:
        """Make reconnect attempt number ``cnt`` (0-based).

        Returns True on success and False on a retryable failure. A
        ClientException is re-raised.
        """
        logger.info(self._fmt_log("trying to reconnect for the {} time", _ordinal(cnt + 1)))
        try:
            await self._connect()
            return True
        except ClientException as e:
            logger.error(self._fmt_log("connect failed, err: {}", e))
            raise e
        except Exception as e:
            logger.error(self._fmt_log("connect failed, err: {}", e))
            return False

    async def _disconnect(self):
        """Close the current connection (if any) and clear all connection state."""
        # Acquire outside the try so we never release a lock we failed to take.
        async with self._lock:
            try:
                if self._conn is None:
                    return
                await self._conn.close()
                logger.info(self._fmt_log("disconnected to {}", self._conn_url))
            finally:
                # Reset even if close() raised, so a later _connect() starts clean.
                self._conn = None
                self._conn_url = ""
                self._conn_id = ""
                self._service_id = ""

    async def _write_message(self, data: bytes):
        """Send ``data`` on the current connection.

        Raises:
            ConnectionClosedException: there is no open connection.
        """
        async with self._lock:
            if self._conn is None:
                raise ConnectionClosedException("connection is closed, write message failed")
            await self._conn.send(data)

    def _combine(self, msg_id: str, sum_: int, seq: int, bs: bytes) -> Optional[bytes]:
        """Buffer fragment ``seq`` of ``sum_`` for ``msg_id``.

        Returns the full payload once every fragment is present, else None.
        An empty fragment counts as not yet received. Buffers are stored in
        ``self._cache`` with an expiry argument of 5 (presumably seconds;
        confirm against ExpiringCache).
        """
        val = self._cache.get(msg_id)
        if val is None:
            buf = [b''] * sum_
            buf[seq] = bs
            self._cache.set(msg_id, buf, 5)
            return None

        val[seq] = bs
        if not all(val):
            # Still missing fragments: refresh the buffer's expiry and keep waiting.
            self._cache.set(msg_id, val, 5)
            return None

        # Join once instead of repeated ``+=`` (which is quadratic for bytes).
        return b''.join(val)

    def _configure(self, conf: ClientConfig) -> None:
        """Apply server-supplied reconnect and ping settings."""
        self._reconnect_count = conf.ReconnectCount
        self._reconnect_interval = conf.ReconnectInterval
        self._reconnect_nonce = conf.ReconnectNonce
        self._ping_interval = conf.PingInterval

    def _fmt_log(self, fmt: str, *args) -> str:
        """Format a log line and append ``[conn_id=...]`` when a connection id is known."""
        log = fmt.format(*args)
        if self._conn_id != "":
            log += f' [conn_id={self._conn_id}]'

        return log
|