286 lines
5.3 KiB
Go
286 lines
5.3 KiB
Go
|
package sati
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/gorilla/websocket"
|
||
|
"github.com/mitchellh/mapstructure"
|
||
|
)
|
||
|
|
||
|
// Connection states stored in socket.state. Every read and write of
// socket.state in this file happens while holding socket.mu.
const (
	// stReconnecting is set by connect once a new connection has been dialed.
	stReconnecting uint32 = 0
	// stConnected is intended for an authenticated, live connection.
	stConnected uint32 = 1
	// stUnrecoverable is terminal: setUnrecoverableState sets it, connector
	// stops reconnecting and call returns the stored error immediately.
	stUnrecoverable uint32 = 2
)
|
||
|
|
||
|
// Wire and bookkeeping types used by socket.
type (
	// message is the JSON envelope exchanged over the websocket in both
	// directions. Its fields carry no json tags, so encoding/json uses the Go
	// field names as keys; the declaration order is also the order in which
	// the fields are written.
	message struct {
		// Type drives routing in reciever: "auth" and "call" are replies,
		// "event" carries a server-pushed event.
		Type string

		// Id is the request id assigned by send.
		Id uint32

		// Data is the payload. After ReadJSON it holds generic JSON values,
		// which are converted further with mapstructure.
		Data any

		// To is the Id of the request this message replies to; reciever
		// ignores replies whose To is 0.
		To uint32
	}

	// outgoingMessage is a message queued for sending together with the
	// channel that should receive its reply (nil when no reply is awaited).
	outgoingMessage struct {
		msg    message
		result chan *incomingMessage
	}

	// incomingMessage is what a waiting goroutine receives: either the reply
	// message or the error that prevented one.
	incomingMessage struct {
		msg *message
		err error
	}
)
|
||
|
|
||
|
type socket struct {
|
||
|
idCounter uint32
|
||
|
state uint32
|
||
|
unrecoverableError error
|
||
|
closeNotifier chan struct{}
|
||
|
outgoing chan *outgoingMessage
|
||
|
ws *websocket.Conn
|
||
|
awaitedReplies map[uint32]chan *incomingMessage
|
||
|
mu *sync.Mutex
|
||
|
events *EventBus
|
||
|
config Config
|
||
|
}
|
||
|
|
||
|
func (s *socket) reciever() {
|
||
|
for {
|
||
|
var message message
|
||
|
if err := s.ws.ReadJSON(&message); err != nil {
|
||
|
if s.config.Debug {
|
||
|
fmt.Println("sati: got error while reading socket", err.Error())
|
||
|
}
|
||
|
s.mu.Lock()
|
||
|
for _, ch := range s.awaitedReplies {
|
||
|
ch <- &incomingMessage{
|
||
|
err: s.unrecoverableError,
|
||
|
}
|
||
|
}
|
||
|
s.awaitedReplies = make(map[uint32]chan *incomingMessage)
|
||
|
s.mu.Unlock()
|
||
|
s.closeNotifier <- struct{}{}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if s.config.Debug {
|
||
|
fmt.Println("sati: recieved message", &message)
|
||
|
}
|
||
|
|
||
|
switch message.Type {
|
||
|
case "auth":
|
||
|
fallthrough
|
||
|
case "call":
|
||
|
if message.To == 0 {
|
||
|
continue
|
||
|
}
|
||
|
s.mu.Lock()
|
||
|
resultCh, ok := s.awaitedReplies[message.To]
|
||
|
if ok {
|
||
|
delete(s.awaitedReplies, message.To)
|
||
|
}
|
||
|
s.mu.Unlock()
|
||
|
if resultCh != nil {
|
||
|
resultCh <- &incomingMessage{msg: &message}
|
||
|
}
|
||
|
case "event":
|
||
|
var event struct {
|
||
|
Type string `json:"type"`
|
||
|
Data any `json:"data"`
|
||
|
}
|
||
|
|
||
|
mapstructure.Decode(message.Data, &event)
|
||
|
if err := s.events.dispatch(event.Type, event.Data); err != nil && s.config.Debug {
|
||
|
fmt.Println("sati: error while dispatching event:", err.Error())
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *socket) send(msg *outgoingMessage) error {
|
||
|
s.idCounter++
|
||
|
msg.msg.Id = s.idCounter
|
||
|
if msg.result != nil {
|
||
|
s.mu.Lock()
|
||
|
s.awaitedReplies[msg.msg.Id] = msg.result
|
||
|
s.mu.Unlock()
|
||
|
}
|
||
|
|
||
|
if s.config.Debug {
|
||
|
fmt.Println("sati: sending message", msg)
|
||
|
}
|
||
|
|
||
|
err := s.ws.WriteJSON(msg.msg)
|
||
|
if msg.result != nil && err != nil {
|
||
|
s.mu.Lock()
|
||
|
s.awaitedReplies[msg.msg.Id] <- &incomingMessage{
|
||
|
err: err,
|
||
|
}
|
||
|
delete(s.awaitedReplies, msg.msg.Id)
|
||
|
s.mu.Unlock()
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (s *socket) sender() {
|
||
|
for {
|
||
|
select {
|
||
|
case msg := <-s.outgoing:
|
||
|
s.send(msg)
|
||
|
case <-s.closeNotifier:
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *socket) connect() error {
|
||
|
if s.config.Debug {
|
||
|
fmt.Println("sati: connecting")
|
||
|
}
|
||
|
ws, _, err := websocket.DefaultDialer.Dial(s.config.Endpoint, http.Header{})
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
s.mu.Lock()
|
||
|
s.state = stReconnecting
|
||
|
s.ws = ws
|
||
|
s.mu.Unlock()
|
||
|
|
||
|
resultChan := make(chan *incomingMessage)
|
||
|
s.send(&outgoingMessage{
|
||
|
message{
|
||
|
Type: "auth",
|
||
|
Data: struct {
|
||
|
Token string `json:"token"`
|
||
|
}{s.config.Token},
|
||
|
}, resultChan,
|
||
|
})
|
||
|
|
||
|
go s.reciever()
|
||
|
|
||
|
rawResult := <-resultChan
|
||
|
if rawResult.err != nil {
|
||
|
return rawResult.err
|
||
|
}
|
||
|
|
||
|
var result struct {
|
||
|
Success bool `json:"success"`
|
||
|
}
|
||
|
|
||
|
if err := mapstructure.Decode(rawResult.msg.Data, &result); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if !result.Success {
|
||
|
s.setUnrecoverableState("invalid auth token")
|
||
|
return s.unrecoverableError
|
||
|
}
|
||
|
|
||
|
s.sender()
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *socket) connector() {
|
||
|
for {
|
||
|
s.mu.Lock()
|
||
|
state := s.state
|
||
|
s.mu.Unlock()
|
||
|
if state == stUnrecoverable {
|
||
|
break
|
||
|
}
|
||
|
|
||
|
err := s.connect() // will block until disconnect
|
||
|
if s.config.Debug && err != nil {
|
||
|
fmt.Println("sati: disconnected", err.Error())
|
||
|
}
|
||
|
time.Sleep(s.config.ReconnectionInterval)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *socket) setUnrecoverableState(err string) {
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
if s.state == stUnrecoverable {
|
||
|
return
|
||
|
}
|
||
|
s.unrecoverableError = fmt.Errorf("sati: %s", err)
|
||
|
s.state = stUnrecoverable
|
||
|
s.ws.Close()
|
||
|
}
|
||
|
|
||
|
func newSocket(config Config) *socket {
|
||
|
s := &socket{
|
||
|
closeNotifier: make(chan struct{}),
|
||
|
outgoing: make(chan *outgoingMessage),
|
||
|
awaitedReplies: make(map[uint32]chan *incomingMessage),
|
||
|
mu: &sync.Mutex{},
|
||
|
events: newEventBus(map[string]any{
|
||
|
"taskUpdate": TaskUpdateEvent{},
|
||
|
"tokenReissue": TokenReissueEvent{},
|
||
|
}),
|
||
|
config: config,
|
||
|
}
|
||
|
|
||
|
s.events.On("tokenReissue", func(any) {
|
||
|
s.setUnrecoverableState("token was reissued")
|
||
|
})
|
||
|
|
||
|
go s.connector()
|
||
|
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
func (s *socket) close() {
|
||
|
s.setUnrecoverableState("socket closed")
|
||
|
}
|
||
|
|
||
|
func (s *socket) call(method string, data any, result any) error {
|
||
|
s.mu.Lock()
|
||
|
if s.state == stUnrecoverable {
|
||
|
err := s.unrecoverableError
|
||
|
s.mu.Unlock()
|
||
|
return err
|
||
|
}
|
||
|
s.mu.Unlock()
|
||
|
|
||
|
resultCh := make(chan *incomingMessage)
|
||
|
s.outgoing <- &outgoingMessage{
|
||
|
msg: message{
|
||
|
Type: "call",
|
||
|
Data: CallMessageOutgoing{
|
||
|
Method: method,
|
||
|
Data: data,
|
||
|
},
|
||
|
},
|
||
|
result: resultCh,
|
||
|
}
|
||
|
|
||
|
resultMsg := <-resultCh
|
||
|
if resultMsg.err != nil {
|
||
|
return resultMsg.err
|
||
|
}
|
||
|
|
||
|
callResult := CallMessageIncoming{}
|
||
|
if err := mapstructure.Decode(resultMsg.msg.Data, &callResult); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if !callResult.Success {
|
||
|
callErr := &CallError{}
|
||
|
if err := mapstructure.Decode(callResult.Data, callErr); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return callErr
|
||
|
}
|
||
|
|
||
|
if err := mapstructure.Decode(callResult.Data, result); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|