package sati import ( "fmt" "net/http" "sync" "time" "github.com/gorilla/websocket" "github.com/mitchellh/mapstructure" ) const ( stReconnecting uint32 = iota stConnected stUnrecoverable ) type message struct { Type string Id uint32 Data any To uint32 } type outgoingMessage struct { msg message result chan *incomingMessage } type incomingMessage struct { msg *message err error } type socket struct { idCounter uint32 state uint32 unrecoverableError error closeNotifier chan struct{} outgoing chan *outgoingMessage ws *websocket.Conn awaitedReplies map[uint32]chan *incomingMessage mu *sync.Mutex events *EventBus config Config } func (s *socket) reciever() { for { var message message if err := s.ws.ReadJSON(&message); err != nil { if s.config.Debug { fmt.Println("sati: got error while reading socket", err.Error()) } s.mu.Lock() for _, ch := range s.awaitedReplies { ch <- &incomingMessage{ err: s.unrecoverableError, } } s.awaitedReplies = make(map[uint32]chan *incomingMessage) s.mu.Unlock() s.closeNotifier <- struct{}{} return } if s.config.Debug { fmt.Println("sati: recieved message", &message) } switch message.Type { case "auth": fallthrough case "call": if message.To == 0 { continue } s.mu.Lock() resultCh, ok := s.awaitedReplies[message.To] if ok { delete(s.awaitedReplies, message.To) } s.mu.Unlock() if resultCh != nil { resultCh <- &incomingMessage{msg: &message} } case "event": var event struct { Type string `json:"type"` Data any `json:"data"` } mapstructure.Decode(message.Data, &event) if err := s.events.dispatch(event.Type, event.Data); err != nil && s.config.Debug { fmt.Println("sati: error while dispatching event:", err.Error()) } } } } func (s *socket) send(msg *outgoingMessage) error { s.idCounter++ msg.msg.Id = s.idCounter if msg.result != nil { s.mu.Lock() s.awaitedReplies[msg.msg.Id] = msg.result s.mu.Unlock() } if s.config.Debug { fmt.Println("sati: sending message", msg) } err := s.ws.WriteJSON(msg.msg) if msg.result != nil && err != nil { s.mu.Lock() s.awaitedReplies[msg.msg.Id] <- &incomingMessage{ err: err, } delete(s.awaitedReplies, msg.msg.Id) s.mu.Unlock() } return err } func (s *socket) sender() { for { select { case msg := <-s.outgoing: s.send(msg) case <-s.closeNotifier: return } } } func (s *socket) connect() error { if s.config.Debug { fmt.Println("sati: connecting") } ws, _, err := websocket.DefaultDialer.Dial(s.config.Endpoint, http.Header{}) if err != nil { return err } s.mu.Lock() s.state = stReconnecting s.ws = ws s.mu.Unlock() resultChan := make(chan *incomingMessage) s.send(&outgoingMessage{ message{ Type: "auth", Data: struct { Token string `json:"token"` }{s.config.Token}, }, resultChan, }) go s.reciever() rawResult := <-resultChan if rawResult.err != nil { return rawResult.err } var result struct { Success bool `json:"success"` } if err := mapstructure.Decode(rawResult.msg.Data, &result); err != nil { return err } if !result.Success { s.setUnrecoverableState("invalid auth token") return s.unrecoverableError } s.sender() return nil } func (s *socket) connector() { for { s.mu.Lock() state := s.state s.mu.Unlock() if state == stUnrecoverable { break } err := s.connect() // will block until disconnect if s.config.Debug && err != nil { fmt.Println("sati: disconnected", err.Error()) } time.Sleep(s.config.ReconnectionInterval) } } func (s *socket) setUnrecoverableState(err string) { s.mu.Lock() defer s.mu.Unlock() if s.state == stUnrecoverable { return } s.unrecoverableError = fmt.Errorf("sati: %s", err) s.state = stUnrecoverable s.ws.Close() } func newSocket(config Config) *socket { s := &socket{ closeNotifier: make(chan struct{}), outgoing: make(chan *outgoingMessage), awaitedReplies: make(map[uint32]chan *incomingMessage), mu: &sync.Mutex{}, events: newEventBus(map[string]any{ "taskUpdate": TaskUpdateEvent{}, "tokenReissue": TokenReissueEvent{}, }), config: config, } s.events.On("tokenReissue", func(any) { s.setUnrecoverableState("token was reissued") }) go s.connector() return s } func (s *socket) close() { s.setUnrecoverableState("socket closed") } func (s *socket) call(method string, data any, result any) error { s.mu.Lock() if s.state == stUnrecoverable { err := s.unrecoverableError s.mu.Unlock() return err } s.mu.Unlock() resultCh := make(chan *incomingMessage) s.outgoing <- &outgoingMessage{ msg: message{ Type: "call", Data: CallMessageOutgoing{ Method: method, Data: data, }, }, result: resultCh, } resultMsg := <-resultCh if resultMsg.err != nil { return resultMsg.err } callResult := CallMessageIncoming{} if err := mapstructure.Decode(resultMsg.msg.Data, &callResult); err != nil { return err } if !callResult.Success { callErr := &CallError{} if err := mapstructure.Decode(callResult.Data, callErr); err != nil { return err } return callErr } if err := mapstructure.Decode(callResult.Data, result); err != nil { return err } return nil }