initial commit
This commit is contained in:
285
socket.go
Normal file
285
socket.go
Normal file
@ -0,0 +1,285 @@
|
||||
package sati
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
stReconnecting uint32 = iota
|
||||
stConnected
|
||||
stUnrecoverable
|
||||
)
|
||||
|
||||
type message struct {
|
||||
Type string
|
||||
Id uint32
|
||||
Data any
|
||||
To uint32
|
||||
}
|
||||
|
||||
type outgoingMessage struct {
|
||||
msg message
|
||||
result chan *incomingMessage
|
||||
}
|
||||
|
||||
type incomingMessage struct {
|
||||
msg *message
|
||||
err error
|
||||
}
|
||||
|
||||
type socket struct {
|
||||
idCounter uint32
|
||||
state uint32
|
||||
unrecoverableError error
|
||||
closeNotifier chan struct{}
|
||||
outgoing chan *outgoingMessage
|
||||
ws *websocket.Conn
|
||||
awaitedReplies map[uint32]chan *incomingMessage
|
||||
mu *sync.Mutex
|
||||
events *EventBus
|
||||
config Config
|
||||
}
|
||||
|
||||
func (s *socket) reciever() {
|
||||
for {
|
||||
var message message
|
||||
if err := s.ws.ReadJSON(&message); err != nil {
|
||||
if s.config.Debug {
|
||||
fmt.Println("sati: got error while reading socket", err.Error())
|
||||
}
|
||||
s.mu.Lock()
|
||||
for _, ch := range s.awaitedReplies {
|
||||
ch <- &incomingMessage{
|
||||
err: s.unrecoverableError,
|
||||
}
|
||||
}
|
||||
s.awaitedReplies = make(map[uint32]chan *incomingMessage)
|
||||
s.mu.Unlock()
|
||||
s.closeNotifier <- struct{}{}
|
||||
return
|
||||
}
|
||||
|
||||
if s.config.Debug {
|
||||
fmt.Println("sati: recieved message", &message)
|
||||
}
|
||||
|
||||
switch message.Type {
|
||||
case "auth":
|
||||
fallthrough
|
||||
case "call":
|
||||
if message.To == 0 {
|
||||
continue
|
||||
}
|
||||
s.mu.Lock()
|
||||
resultCh, ok := s.awaitedReplies[message.To]
|
||||
if ok {
|
||||
delete(s.awaitedReplies, message.To)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
if resultCh != nil {
|
||||
resultCh <- &incomingMessage{msg: &message}
|
||||
}
|
||||
case "event":
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
||||
mapstructure.Decode(message.Data, &event)
|
||||
if err := s.events.dispatch(event.Type, event.Data); err != nil && s.config.Debug {
|
||||
fmt.Println("sati: error while dispatching event:", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *socket) send(msg *outgoingMessage) error {
|
||||
s.idCounter++
|
||||
msg.msg.Id = s.idCounter
|
||||
if msg.result != nil {
|
||||
s.mu.Lock()
|
||||
s.awaitedReplies[msg.msg.Id] = msg.result
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
if s.config.Debug {
|
||||
fmt.Println("sati: sending message", msg)
|
||||
}
|
||||
|
||||
err := s.ws.WriteJSON(msg.msg)
|
||||
if msg.result != nil && err != nil {
|
||||
s.mu.Lock()
|
||||
s.awaitedReplies[msg.msg.Id] <- &incomingMessage{
|
||||
err: err,
|
||||
}
|
||||
delete(s.awaitedReplies, msg.msg.Id)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *socket) sender() {
|
||||
for {
|
||||
select {
|
||||
case msg := <-s.outgoing:
|
||||
s.send(msg)
|
||||
case <-s.closeNotifier:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *socket) connect() error {
|
||||
if s.config.Debug {
|
||||
fmt.Println("sati: connecting")
|
||||
}
|
||||
ws, _, err := websocket.DefaultDialer.Dial(s.config.Endpoint, http.Header{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.state = stReconnecting
|
||||
s.ws = ws
|
||||
s.mu.Unlock()
|
||||
|
||||
resultChan := make(chan *incomingMessage)
|
||||
s.send(&outgoingMessage{
|
||||
message{
|
||||
Type: "auth",
|
||||
Data: struct {
|
||||
Token string `json:"token"`
|
||||
}{s.config.Token},
|
||||
}, resultChan,
|
||||
})
|
||||
|
||||
go s.reciever()
|
||||
|
||||
rawResult := <-resultChan
|
||||
if rawResult.err != nil {
|
||||
return rawResult.err
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
if err := mapstructure.Decode(rawResult.msg.Data, &result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
s.setUnrecoverableState("invalid auth token")
|
||||
return s.unrecoverableError
|
||||
}
|
||||
|
||||
s.sender()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *socket) connector() {
|
||||
for {
|
||||
s.mu.Lock()
|
||||
state := s.state
|
||||
s.mu.Unlock()
|
||||
if state == stUnrecoverable {
|
||||
break
|
||||
}
|
||||
|
||||
err := s.connect() // will block until disconnect
|
||||
if s.config.Debug && err != nil {
|
||||
fmt.Println("sati: disconnected", err.Error())
|
||||
}
|
||||
time.Sleep(s.config.ReconnectionInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *socket) setUnrecoverableState(err string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.state == stUnrecoverable {
|
||||
return
|
||||
}
|
||||
s.unrecoverableError = fmt.Errorf("sati: %s", err)
|
||||
s.state = stUnrecoverable
|
||||
s.ws.Close()
|
||||
}
|
||||
|
||||
func newSocket(config Config) *socket {
|
||||
s := &socket{
|
||||
closeNotifier: make(chan struct{}),
|
||||
outgoing: make(chan *outgoingMessage),
|
||||
awaitedReplies: make(map[uint32]chan *incomingMessage),
|
||||
mu: &sync.Mutex{},
|
||||
events: newEventBus(map[string]any{
|
||||
"taskUpdate": TaskUpdateEvent{},
|
||||
"tokenReissue": TokenReissueEvent{},
|
||||
}),
|
||||
config: config,
|
||||
}
|
||||
|
||||
s.events.On("tokenReissue", func(any) {
|
||||
s.setUnrecoverableState("token was reissued")
|
||||
})
|
||||
|
||||
go s.connector()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *socket) close() {
|
||||
s.setUnrecoverableState("socket closed")
|
||||
}
|
||||
|
||||
func (s *socket) call(method string, data any, result any) error {
|
||||
s.mu.Lock()
|
||||
if s.state == stUnrecoverable {
|
||||
err := s.unrecoverableError
|
||||
s.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
resultCh := make(chan *incomingMessage)
|
||||
s.outgoing <- &outgoingMessage{
|
||||
msg: message{
|
||||
Type: "call",
|
||||
Data: CallMessageOutgoing{
|
||||
Method: method,
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
result: resultCh,
|
||||
}
|
||||
|
||||
resultMsg := <-resultCh
|
||||
if resultMsg.err != nil {
|
||||
return resultMsg.err
|
||||
}
|
||||
|
||||
callResult := CallMessageIncoming{}
|
||||
if err := mapstructure.Decode(resultMsg.msg.Data, &callResult); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !callResult.Success {
|
||||
callErr := &CallError{}
|
||||
if err := mapstructure.Decode(callResult.Data, callErr); err != nil {
|
||||
return err
|
||||
}
|
||||
return callErr
|
||||
}
|
||||
|
||||
if err := mapstructure.Decode(callResult.Data, result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user