sati-py/sati/socket.py

158 lines
5.0 KiB
Python

import json
import typing
from dataclasses import dataclass
import asyncio
import websockets.client as wsc
from .util import SatiDict
# Connection states of SatiSocket:
#   STATE_CONNECTED     - authenticated; call() sends requests directly
#   STATE_RECONNECTING  - (re)connecting; call() queues requests for replay
#   STATE_UNRECOVERABLE - given up; call() raises the stored error
STATE_CONNECTED, STATE_RECONNECTING, STATE_UNRECOVERABLE = range(3)
class SatiUnrecoverableException(Exception):
    """Fatal socket error after which no reconnection is attempted.

    The message is stored with a ``"sati: "`` prefix.
    """

    def __init__(self, message: str):
        text = "sati: {}".format(message)
        super().__init__(text)
class SatiException(Exception):
    """API error returned for an unsuccessful call.

    ``code`` holds the numeric error code (0 when none is given).  The
    message has the form ``"sati: #<code>: <message>"``.
    """

    def __init__(self, message: str, code: int = 0):
        self.code = code
        super().__init__("sati: #{}: {}".format(code, message))
@dataclass()
class QueueEntry:
    """A call deferred while the socket is not connected.

    ``fut`` is settled with the call's result or exception once the call is
    replayed or rejected; ``method`` and ``data`` are the original arguments.
    """
    fut: asyncio.Future  # awaited by the original caller
    method: str  # API method name
    data: dict  # payload passed along with the method
class SatiSocket:
    ''' low-level api wrapper

    Maintains one websocket connection to the Sati API from a background
    task. The task connects, authenticates with ``token``, and retries
    ``reconnection_interval`` seconds after any connection failure.
    ``call()`` sends a request and awaits its reply. Calls made while
    (re)connecting are queued and replayed after authentication. Once the
    socket becomes unrecoverable (token rejected, or ``close()`` called),
    every pending and subsequent call raises ``SatiUnrecoverableException``.

    Must be constructed while an event loop is running, because
    ``__init__`` starts the connector with ``asyncio.create_task``.
    '''

    def __init__(
        self,
        token: str,
        reconnection_interval: float = 1,
        url = 'wss://api.sati.ac/ws',
        debug = False
    ):
        # reply futures keyed by the id of the message that awaits them
        self.__awaited_replies = {}
        # QueueEntry list of calls made while not connected
        self.__queue = []
        # event type -> handlers registered with on()
        self.__event_handlers = {}
        # strong references to fire-and-forget tasks; the event loop only
        # keeps weak references, so unreferenced tasks may be collected
        self.__background = set()
        self.__token = token
        self.__reconnection_interval = reconnection_interval
        self.__url = url
        self.__debug = debug
        self.__error = None
        self.__id_counter = 0
        # stays None until the first successful connect; close() checks it
        self.__socket = None
        self.__state = STATE_RECONNECTING
        # start the connector only after every attribute it reads exists
        self.__connector_ref = asyncio.create_task(self.__connector())

    async def __connector(self):
        '''Background task: connect, serve and reconnect until unrecoverable.

        Transient failures are printed and retried after
        ``reconnection_interval`` seconds. A ``SatiUnrecoverableException``,
        or cancellation of this task by ``close()``, ends the loop and
        fails every pending and queued call.
        '''
        try:
            while True:
                try:
                    await self.__connect()
                except SatiUnrecoverableException as ex:
                    self.__give_up(ex)
                    return
                except asyncio.CancelledError:
                    # listed before Exception for Python versions where
                    # CancelledError is still an Exception subclass
                    raise
                except Exception as ex:
                    # calls must be queued, not sent, until re-authenticated
                    self.__state = STATE_RECONNECTING
                    print(ex)
                await asyncio.sleep(self.__reconnection_interval)
        except asyncio.CancelledError as ex:
            # cancellation can land in __connect or in the sleep; both mean
            # the socket is closed for good
            err = SatiUnrecoverableException('socket closed')
            err.__cause__ = ex
            self.__give_up(err)

    async def __connect(self):
        '''Open one connection, authenticate, replay queued calls, then serve.

        Exits only by raising: when the reader stops (it always raises),
        when authentication is rejected
        (``SatiUnrecoverableException('invalid token')``), or on
        cancellation. On every exit path the reader task is cancelled and
        the socket is closed, so connections are never leaked across
        reconnects.
        '''
        self.__state = STATE_RECONNECTING
        self.__socket = await wsc.connect(self.__url)
        reader = asyncio.create_task(self.__reader())
        try:
            auth_resp = await self.__send('auth', { 'token': self.__token })
            if not auth_resp.success:
                raise SatiUnrecoverableException('invalid token')
            self.__state = STATE_CONNECTED
            queued, self.__queue = self.__queue, []
            for entry in queued:
                self.__spawn(self.__resend_call(entry))
            await reader
        finally:
            reader.cancel()
            await self.__socket.close()

    async def __resend_call(self, call: QueueEntry):
        '''Replay a queued call and settle its future with the outcome.'''
        if call.fut.done():
            # the original caller was cancelled; don't send the call at all
            return
        try:
            result = await self.call(call.method, call.data)
        except Exception as ex:
            self.__reject(call.fut, ex)
        else:
            if not call.fut.done():
                call.fut.set_result(result)

    async def __send(self, msg_type: str, data: dict) -> dict:
        '''Send one message. For ``'auth'`` and ``'call'``, await and return
        the reply's ``data``.

        The reply future is registered before sending so that a fast reply
        still finds it. It is always unregistered afterwards (reply
        received, send failed, or caller cancelled) so the reply table does
        not grow without bound.
        '''
        self.__id_counter += 1
        msg_id = self.__id_counter
        if self.__debug:
            print(f'sending message {msg_type} with id {msg_id}', data)
        fut = None
        if msg_type in ( 'auth', 'call' ):
            fut = asyncio.get_running_loop().create_future()
            self.__awaited_replies[msg_id] = fut
        try:
            await self.__socket.send(json.dumps({
                'id': msg_id,
                'type': msg_type,
                'data': data
            }))
            if fut is not None:
                return await fut
        finally:
            self.__awaited_replies.pop(msg_id, None)

    async def call(self, method: str, data: typing.Optional[dict] = None) -> SatiDict:
        '''Invoke API ``method`` with ``data`` (``None`` means ``{}``).

        Returns the reply's ``data`` on success. Raises ``SatiException``
        with the server's description and code when the reply is
        unsuccessful. While reconnecting, the call is queued and awaited
        until it is replayed. Once the socket is unrecoverable, the stored
        ``SatiUnrecoverableException`` is raised.
        '''
        if data is None:
            data = {}
        if self.__state == STATE_CONNECTED:
            resp = await self.__send('call', {
                'method': method,
                'data': data
            })
            if not resp.success:
                raise SatiException(resp.data.description, code=resp.data.code)
            return resp.data
        if self.__state == STATE_RECONNECTING:
            fut = asyncio.get_running_loop().create_future()
            self.__queue.append(QueueEntry(fut, method, data))
            return await fut
        if self.__state == STATE_UNRECOVERABLE:
            raise self.__error

    async def __reader(self):
        '''Dispatch incoming frames until the connection ends.

        Replies (``type`` ``'auth'``/``'call'``) resolve the future
        registered under their ``to`` id. ``'event'`` frames are passed to
        the handlers for ``data.type``. Always ends by raising: on any
        error, and with ``ConnectionError`` when the stream closes cleanly.
        Before re-raising, every awaited reply is failed with the error so
        that no caller hangs.
        '''
        try:
            async for raw in self.__socket:
                # NOTE(review): SatiDict is assumed to expose keys as attributes
                msg = SatiDict(json.loads(raw))
                if self.__debug:
                    print('recieved message', msg)
                if msg.type in ( 'auth', 'call' ):
                    fut = self.__awaited_replies.pop(msg.to, None)
                    # the caller may have been cancelled meanwhile
                    if fut is not None and not fut.done():
                        fut.set_result(msg.data)
                elif msg.type == 'event':
                    for handler in self.__event_handlers.get(msg.data.type, ()):
                        handler(msg.data.data)
            # clean close: treat it like a failure so the connector waits
            # before reconnecting and pending callers are released
            raise ConnectionError('connection closed')
        except Exception as ex:
            self.__fail_awaited(ex)
            raise

    def __spawn(self, coro):
        '''Run ``coro`` as a task, keeping a reference until it finishes.'''
        task = asyncio.create_task(coro)
        self.__background.add(task)
        task.add_done_callback(self.__background.discard)

    @staticmethod
    def __reject(fut, ex):
        '''Fail ``fut`` with ``ex`` unless it is already settled or cancelled.'''
        if not fut.done():
            fut.set_exception(ex)

    def __fail_awaited(self, ex):
        '''Fail and forget every future still waiting for a reply.'''
        # swap the dict first so nothing mutates it while we iterate
        replies, self.__awaited_replies = self.__awaited_replies, {}
        for fut in replies.values():
            self.__reject(fut, ex)

    def __give_up(self, ex):
        '''Enter the unrecoverable state and fail all queued and awaited calls.'''
        self.__state = STATE_UNRECOVERABLE
        self.__error = ex
        queued, self.__queue = self.__queue, []
        for entry in queued:
            self.__reject(entry.fut, ex)
        self.__fail_awaited(ex)

    async def close(self):
        '''Stop reconnecting, fail every pending call and close the socket.

        After this returns, ``call()`` raises
        ``SatiUnrecoverableException('socket closed')``, or the error that
        had already made the socket unrecoverable.
        '''
        self.__connector_ref.cancel()
        # let the connector run its cancellation handling (it fails pending
        # calls and closes the live socket) before returning
        await asyncio.wait([self.__connector_ref])
        if self.__state != STATE_UNRECOVERABLE:
            # the connector was cancelled before its first step, so its own
            # cancellation handler never ran
            self.__give_up(SatiUnrecoverableException('socket closed'))
        if self.__socket is not None:
            await self.__socket.close()

    def on(self, event: str, handler: typing.Callable[[SatiDict], None]):
        '''Register ``handler`` for server events whose ``type`` is ``event``.

        Handlers are called synchronously by the reader with the event's
        ``data``. An exception raised by a handler ends the current
        connection, which is then re-established.
        '''
        self.__event_handlers.setdefault(event, []).append(handler)