// sati-go/socket.go — Go source, 286 lines, 5.3 KiB.

package sati
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/mitchellh/mapstructure"
)
// Connection states stored in socket.state.
const (
	// stReconnecting is set while a connection is being dialed and
	// authenticated.
	stReconnecting uint32 = 0
	// stConnected names an authenticated, live connection.
	stConnected uint32 = 1
	// stUnrecoverable is terminal: the connector stops reconnecting and
	// call returns socket.unrecoverableError.
	stUnrecoverable uint32 = 2
)
type (
	// message is the wire envelope exchanged with the server. It carries no
	// json tags, so encoding/json uses the Go field names as keys.
	message struct {
		Type string
		// Id is the request id assigned by send.
		Id   uint32
		Data any
		// To names the Id of the request this message answers; the reciever
		// ignores auth/call messages whose To is 0.
		To uint32
	}

	// outgoingMessage is a request queued for sending. When result is
	// non-nil, the reply (or a failure) is delivered on it.
	outgoingMessage struct {
		msg    message
		result chan *incomingMessage
	}

	// incomingMessage carries either a reply (msg) or a failure (err) to the
	// goroutine waiting on a request.
	incomingMessage struct {
		msg *message
		err error
	}
)
type socket struct {
idCounter uint32
state uint32
unrecoverableError error
closeNotifier chan struct{}
outgoing chan *outgoingMessage
ws *websocket.Conn
awaitedReplies map[uint32]chan *incomingMessage
mu *sync.Mutex
events *EventBus
config Config
}
// reciever reads messages from the current websocket until a read fails.
// "auth" and "call" replies are routed to the channel registered for the
// request id in To. "event" messages are decoded and handed to the event bus.
// When a read fails, every pending request gets a non-nil error, the sender
// loop is signalled through closeNotifier, and the goroutine returns.
func (s *socket) reciever() {
	for {
		var msg message
		if err := s.ws.ReadJSON(&msg); err != nil {
			if s.config.Debug {
				fmt.Println("sati: got error while reading socket", err.Error())
			}
			// Swap the pending map out under the lock, then deliver outside it
			// so goroutines that only need the lock are not held up.
			s.mu.Lock()
			pending := s.awaitedReplies
			s.awaitedReplies = make(map[uint32]chan *incomingMessage)
			replyErr := s.unrecoverableError
			s.mu.Unlock()
			// unrecoverableError is nil on an ordinary disconnect. Waiters must
			// still get a non-nil error; otherwise they treat the reply as a
			// success carrying a nil message.
			if replyErr == nil {
				replyErr = fmt.Errorf("sati: connection lost: %w", err)
			}
			for _, ch := range pending {
				ch <- &incomingMessage{err: replyErr}
			}
			s.closeNotifier <- struct{}{}
			return
		}
		if s.config.Debug {
			fmt.Println("sati: recieved message", &msg)
		}
		switch msg.Type {
		case "auth", "call":
			if msg.To == 0 {
				continue
			}
			s.mu.Lock()
			resultCh, ok := s.awaitedReplies[msg.To]
			delete(s.awaitedReplies, msg.To)
			s.mu.Unlock()
			if ok {
				resultCh <- &incomingMessage{msg: &msg}
			}
		case "event":
			var event struct {
				Type string `json:"type"`
				Data any    `json:"data"`
			}
			// Do not dispatch a partially decoded event.
			if err := mapstructure.Decode(msg.Data, &event); err != nil {
				if s.config.Debug {
					fmt.Println("sati: error while decoding event:", err.Error())
				}
				continue
			}
			if err := s.events.dispatch(event.Type, event.Data); err != nil && s.config.Debug {
				fmt.Println("sati: error while dispatching event:", err.Error())
			}
		}
	}
}
// send gives msg the next request id and writes it to the current
// connection.
//
// If msg.result is non-nil, it is registered so the reciever can route the
// reply to it. If the write then fails, the registration is dropped and the
// error is delivered on msg.result instead. The write error is also returned.
//
// send only runs on the connector goroutine (from connect and sender), which
// is why idCounter is not guarded by mu.
func (s *socket) send(msg *outgoingMessage) error {
	s.idCounter++
	msg.msg.Id = s.idCounter
	if msg.result != nil {
		s.mu.Lock()
		s.awaitedReplies[msg.msg.Id] = msg.result
		s.mu.Unlock()
	}
	if s.config.Debug {
		fmt.Println("sati: sending message", msg)
	}
	err := s.ws.WriteJSON(msg.msg)
	if err == nil || msg.result == nil {
		return err
	}
	// Claim the entry under the lock. If the reciever already failed the
	// pending requests, the entry is gone and the waiter has its error.
	s.mu.Lock()
	_, pending := s.awaitedReplies[msg.msg.Id]
	delete(s.awaitedReplies, msg.msg.Id)
	s.mu.Unlock()
	if pending {
		// Deliver asynchronously. The requester may not be receiving yet:
		// connect sends the auth request before it starts waiting.
		go func(ch chan *incomingMessage) {
			ch <- &incomingMessage{err: err}
		}(msg.result)
	}
	return err
}
// sender is the write loop for the current connection. It forwards requests
// queued on s.outgoing to send until the reciever signals closeNotifier.
func (s *socket) sender() {
	for {
		select {
		case <-s.closeNotifier:
			return
		case out := <-s.outgoing:
			// A write error reaches the waiting caller through its result
			// channel, so the return value is not needed here.
			s.send(out)
		}
	}
}
// connect dials the configured endpoint, authenticates with the configured
// token, and then runs the sender loop on the calling goroutine. It therefore
// blocks for the lifetime of a single connection.
//
// It returns nil once an authenticated connection drops. It returns an error
// when the dial or the auth handshake fails, or when the socket became
// unrecoverable during the dial. An auth reply with success=false puts the
// socket into the unrecoverable state.
func (s *socket) connect() error {
	if s.config.Debug {
		fmt.Println("sati: connecting")
	}
	ws, _, err := websocket.DefaultDialer.Dial(s.config.Endpoint, http.Header{})
	if err != nil {
		return err
	}
	s.mu.Lock()
	if s.state == stUnrecoverable {
		// close() (or a token reissue) ran while we were dialing. Do not
		// resurrect the socket by overwriting the terminal state.
		unrecoverableErr := s.unrecoverableError
		s.mu.Unlock()
		ws.Close()
		return unrecoverableErr
	}
	s.state = stReconnecting
	s.ws = ws
	s.mu.Unlock()

	// fail tears down a connection that never finished authenticating.
	// Closing ws makes the reciever exit. Its closeNotifier signal is consumed
	// here so the next connection's sender loop does not pick it up and stop
	// immediately.
	fail := func(err error) error {
		ws.Close()
		<-s.closeNotifier
		return err
	}

	// Buffered, so the reply or a write error can be delivered before we
	// start waiting on it.
	resultChan := make(chan *incomingMessage, 1)
	// A write error is reported through resultChan as well.
	s.send(&outgoingMessage{
		msg: message{
			Type: "auth",
			Data: struct {
				Token string `json:"token"`
			}{s.config.Token},
		},
		result: resultChan,
	})
	go s.reciever()
	rawResult := <-resultChan
	if rawResult.err != nil {
		return fail(rawResult.err)
	}
	if rawResult.msg == nil {
		return fail(fmt.Errorf("sati: no auth reply received"))
	}
	var result struct {
		Success bool `json:"success"`
	}
	if err := mapstructure.Decode(rawResult.msg.Data, &result); err != nil {
		return fail(err)
	}
	if !result.Success {
		s.setUnrecoverableState("invalid auth token")
		return fail(s.unrecoverableError)
	}

	s.mu.Lock()
	if s.state != stUnrecoverable {
		s.state = stConnected
	}
	s.mu.Unlock()
	s.sender()
	// sender returns only after the reciever has exited; release the conn.
	ws.Close()
	return nil
}
// connector keeps the socket connected. It calls connect, which blocks while
// a connection is up, and waits ReconnectionInterval between attempts. It
// returns once the socket is unrecoverable.
func (s *socket) connector() {
	for {
		s.mu.Lock()
		unrecoverable := s.state == stUnrecoverable
		s.mu.Unlock()
		if unrecoverable {
			return
		}
		if err := s.connect(); err != nil && s.config.Debug {
			fmt.Println("sati: disconnected", err.Error())
		}
		time.Sleep(s.config.ReconnectionInterval)
	}
}
// setUnrecoverableState moves the socket into its terminal state. The reason
// is recorded as the error "sati: <err>", and the current connection (if any)
// is closed so its reciever exits. Only the first call has an effect.
func (s *socket) setUnrecoverableState(err string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.state == stUnrecoverable {
		return
	}
	s.unrecoverableError = fmt.Errorf("sati: %s", err)
	s.state = stUnrecoverable
	// ws stays nil until the first dial succeeds (e.g. close() right after
	// newSocket). Closing a nil *websocket.Conn would panic.
	if s.ws != nil {
		s.ws.Close()
	}
}
// newSocket builds a socket for config and starts its connector goroutine.
// A "tokenReissue" event moves the socket into the unrecoverable state.
func newSocket(config Config) *socket {
	bus := newEventBus(map[string]any{
		"taskUpdate":   TaskUpdateEvent{},
		"tokenReissue": TokenReissueEvent{},
	})
	s := &socket{
		config:         config,
		events:         bus,
		mu:             new(sync.Mutex),
		awaitedReplies: make(map[uint32]chan *incomingMessage),
		outgoing:       make(chan *outgoingMessage),
		closeNotifier:  make(chan struct{}),
	}
	bus.On("tokenReissue", func(any) {
		s.setUnrecoverableState("token was reissued")
	})
	go s.connector()
	return s
}
// close shuts the socket down for good. Unless the socket was already
// unrecoverable, later calls fail with "sati: socket closed", and the
// connector stops reconnecting.
func (s *socket) close() {
	const reason = "socket closed"
	s.setUnrecoverableState(reason)
}
// call invokes the remote method with data as its payload and blocks until
// the server replies.
//
// On success, the reply data is decoded into result with mapstructure.Decode.
// result must be a pointer, or nil to discard the data. A reply whose Success
// is false is returned as a *CallError. Transport failures, and the socket's
// terminal error once it is unrecoverable, are returned as-is.
//
// NOTE(review): requests are only drained by the sender loop, which runs
// while a connection is authenticated. A call made during reconnection waits
// for it. A call that passes the state check right before the socket becomes
// unrecoverable can block on s.outgoing forever. Consider a done channel or
// a context.
func (s *socket) call(method string, data any, result any) error {
	s.mu.Lock()
	if s.state == stUnrecoverable {
		err := s.unrecoverableError
		s.mu.Unlock()
		return err
	}
	s.mu.Unlock()
	// Buffered so whoever delivers the reply never blocks on us.
	resultCh := make(chan *incomingMessage, 1)
	s.outgoing <- &outgoingMessage{
		msg: message{
			Type: "call",
			Data: CallMessageOutgoing{
				Method: method,
				Data:   data,
			},
		},
		result: resultCh,
	}
	resultMsg := <-resultCh
	if resultMsg.err != nil {
		return resultMsg.err
	}
	if resultMsg.msg == nil {
		return fmt.Errorf("sati: call %q: no reply received", method)
	}
	callResult := CallMessageIncoming{}
	if err := mapstructure.Decode(resultMsg.msg.Data, &callResult); err != nil {
		return err
	}
	if !callResult.Success {
		callErr := &CallError{}
		if err := mapstructure.Decode(callResult.Data, callErr); err != nil {
			return err
		}
		return callErr
	}
	if result == nil {
		// mapstructure rejects a nil output; the caller wants no data.
		return nil
	}
	return mapstructure.Decode(callResult.Data, result)
}