bridge/main.go

333 lines
7.9 KiB
Go

package main
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"flag"
"fmt"
"math/big"
"net/http"
"os"
"os/signal"
"regexp"
"runtime"
"strings"
"time"
"git.sati.ac/sati.ac/bridge/api"
"git.sati.ac/sati.ac/bridge/config"
"git.sati.ac/sati.ac/sati-go"
"github.com/sirupsen/logrus"
)
// bridge bundles the configuration loaded at start-up.
// NOTE(review): nothing in the visible code references this type.
type bridge struct{ config *config.Config }
// issueCACert generates a self-signed RSA-2048 CA certificate ("Bridge CA",
// valid for 100 years) and writes it PEM-encoded to certPath, with its PKCS#1
// private key PEM-encoded to keyPath.
//
// Files that already exist at either path are truncated and fully replaced.
// New files are created with mode 0600.
func issueCACert(certPath string, keyPath string) error {
	// RFC 5280 §4.1.2.2 requires a positive serial number. A random 128-bit
	// value also keeps independently re-issued CAs from sharing one.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return err
	}
	serial.Add(serial, big.NewInt(1)) // rand.Int may return 0

	now := time.Now()
	cert := &x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			CommonName:   "Bridge CA",
			Country:      []string{"VA"},
			Organization: []string{"sati.ac"},
			Locality:     []string{"Everywhere"},
		},
		NotBefore:             now,
		NotAfter:              now.AddDate(100, 0, 0),
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}
	keys, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return err
	}
	// Self-signed: the template doubles as its own parent.
	der, err := x509.CreateCertificate(rand.Reader, cert, cert, &keys.PublicKey, keys)
	if err != nil {
		return err
	}
	if err := writePEMFile(certPath, "CERTIFICATE", der); err != nil {
		return err
	}
	return writePEMFile(keyPath, "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(keys))
}

// writePEMFile writes a single PEM block of blockType containing data to path.
// The file is created with mode 0600, or truncated if it already exists, so no
// bytes of a previous, longer file survive behind the new block. The Close
// error is returned too, since some filesystems only report write failures
// when the file is closed.
func writePEMFile(path string, blockType string, data []byte) error {
	f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	if err := pem.Encode(f, &pem.Block{Type: blockType, Bytes: data}); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}
// issueCert issues a new leaf certificate for domains, signed by the CA stored
// PEM-encoded at caCertPath (certificate) and caKeyPath (PKCS#1 RSA key).
//
// The leaf uses a fresh RSA-2048 key and is valid for one year from now; it is
// intended to be re-issued on every start. The certificate and its private key
// are written PEM-encoded to two new files in the OS temp directory. Their
// paths are returned in that order, and the caller must remove them. If an
// error is returned, no temporary files are left behind.
func issueCert(domains []string, caCertPath string, caKeyPath string) (string, string, error) {
	// Load the CA first so a broken CA fails fast, before key generation.
	caCertBlock, err := readPEMBlock(caCertPath, "CERTIFICATE", "certificate")
	if err != nil {
		return "", "", err
	}
	caCert, err := x509.ParseCertificate(caCertBlock.Bytes)
	if err != nil {
		return "", "", err
	}
	caKeyBlock, err := readPEMBlock(caKeyPath, "RSA PRIVATE KEY", "key")
	if err != nil {
		return "", "", err
	}
	caKey, err := x509.ParsePKCS1PrivateKey(caKeyBlock.Bytes)
	if err != nil {
		return "", "", err
	}

	// A random 128-bit serial avoids the collisions a seconds-resolution
	// timestamp produces when two certificates are issued in the same second.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return "", "", err
	}
	serial.Add(serial, big.NewInt(1)) // rand.Int may return 0; serials must be positive

	now := time.Now()
	template := &x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Country:            []string{"VA"},
			Organization:       []string{"sati.ac"},
			OrganizationalUnit: []string{"Bridge ephemeral certificate"},
			Locality:           []string{"Everywhere"},
		},
		DNSNames:  domains,
		NotBefore: now,
		// Apple platforms reject TLS server certificates valid for more than
		// 825 days; the leaf is regenerated on each start anyway.
		NotAfter:     now.AddDate(1, 0, 0),
		SubjectKeyId: []byte{0},
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:     x509.KeyUsageDigitalSignature,
	}
	keys, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return "", "", err
	}
	der, err := x509.CreateCertificate(rand.Reader, template, caCert, &keys.PublicKey, caKey)
	if err != nil {
		return "", "", err
	}

	certPath, err := writeTempPEM("bridge*.crt", "CERTIFICATE", der)
	if err != nil {
		return "", "", err
	}
	keyPath, err := writeTempPEM("bridge*.key", "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(keys))
	if err != nil {
		os.Remove(certPath) // best effort; the write error is the one to report
		return "", "", err
	}
	return certPath, keyPath, nil
}

// readPEMBlock reads path and returns its first PEM block, which must be of
// type wantType. label ("certificate", "key") prefixes format errors.
func readPEMBlock(path string, wantType string, label string) (*pem.Block, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	block, _ := pem.Decode(data)
	if block == nil {
		// pem.Decode returns nil when the file holds no PEM data, so
		// block.Type must not be touched in that case.
		return nil, fmt.Errorf("%s: no pem block found in %s", label, path)
	}
	if block.Type != wantType {
		return nil, fmt.Errorf(`%s: bad pem block "%s"`, label, block.Type)
	}
	return block, nil
}

// writeTempPEM creates a new file in the OS temp directory named after pattern
// (see os.CreateTemp), writes one PEM block of blockType containing data into
// it, and returns its path. If writing or closing fails, the file is removed
// and the error is returned.
func writeTempPEM(pattern string, blockType string, data []byte) (string, error) {
	f, err := os.CreateTemp("", pattern)
	if err != nil {
		return "", err
	}
	name := f.Name()
	err = pem.Encode(f, &pem.Block{Type: blockType, Bytes: data})
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		os.Remove(name)
		return "", err
	}
	return name, nil
}
// getHostsPath returns the location of the system hosts file for the current
// GOOS, or an error naming the OS if it is not recognised.
//
// On Windows the path is derived from %SystemRoot%, so systems installed
// outside C:\Windows are handled. C:\Windows is used only when that variable
// is unset.
func getHostsPath() (string, error) {
	switch runtime.GOOS {
	case "linux", "android", "darwin", "freebsd", "openbsd", "netbsd", "dragonfly", "solaris", "illumos":
		return "/etc/hosts", nil
	case "windows":
		if root := os.Getenv("SystemRoot"); root != "" {
			return strings.TrimRight(root, `\/`) + `\System32\drivers\etc\hosts`, nil
		}
		return `C:\Windows\System32\drivers\etc\hosts`, nil
	}
	return "", fmt.Errorf("unknown os: %s", runtime.GOOS)
}
var (
	// hostsModRE matches the bridge-managed section of a hosts file: a line
	// break, the "#sati-bridge start, DO NOT MODIFY" marker, the entry lines,
	// and the "#sati-bridge end" marker. Any line-ending style (\n, \r\n, \r)
	// is accepted. The entry lines may not contain '#', and a start marker at
	// the very beginning of the input (no preceding line break) does not match.
	hostsModRE = regexp.MustCompile("(?:\n|\r\n|\r)#sati-bridge start, DO NOT MODIFY(?:\n|\r\n|\r)[^#]*(?:\n|\r\n|\r)#sati-bridge end")

	// configPath holds the -config command-line flag: where the bridge config
	// lives (default ./data/config.json).
	configPath = flag.String("config", "./data/config.json", "config path")
)
// addDomainsToHosts maps every domain returned by ctx.Server.GetDomains() to
// the bridge's HTTP listen host by appending a marked section to the system
// hosts file.
//
// A section left by a previous run (anything hostsModRE matches) is removed
// first, so repeated calls do not accumulate entries. Trailing whitespace of
// the existing file is trimmed before the section is appended. Failures are
// logged as warnings and returned.
func addDomainsToHosts(ctx *api.ApiContext) error {
	ctx.Logger.Info("adding domains to hosts")
	path, err := getHostsPath()
	if err != nil {
		ctx.Logger.WithError(err).Warn("unable to get hosts path")
		return err
	}
	hosts, err := os.ReadFile(path)
	if err != nil {
		ctx.Logger.WithError(err).Warn("unable to read hosts file")
		return err
	}
	hosts = hostsModRE.ReplaceAll(hosts, []byte{}) // remove old entries

	ip := hostsIP(ctx.Config.Host)
	var section strings.Builder
	// The leading line break is part of what hostsModRE expects to remove.
	section.WriteString("\r\n#sati-bridge start, DO NOT MODIFY\r\n")
	for _, domain := range ctx.Server.GetDomains() {
		section.WriteString(ip + " " + domain + "\r\n")
	}
	section.WriteString("#sati-bridge end\r\n")

	hosts = []byte(strings.TrimRight(string(hosts), "\r\n\t ") + section.String())
	err = os.WriteFile(path, hosts, 0644)
	if err != nil {
		ctx.Logger.WithError(err).Warn("unable to write hosts file")
	}
	return err
}

// hostsIP extracts the host part of a listen address such as "127.0.0.1:8080"
// or "[::1]:8080" for use as the address column of a hosts-file entry.
//
// An empty or wildcard host (":8080", "0.0.0.0:8080", "[::]:8080") would yield
// a malformed line or an address that is not a valid destination on every
// platform (e.g. Windows), so it is replaced by the matching loopback address.
func hostsIP(addr string) string {
	var host string
	if strings.HasPrefix(addr, "[") {
		// Bracketed IPv6 literal: keep what is inside the brackets.
		if end := strings.Index(addr, "]"); end > 0 {
			host = addr[1:end]
		}
	} else {
		host = strings.SplitN(addr, ":", 2)[0]
	}
	switch host {
	case "", "0.0.0.0":
		return "127.0.0.1"
	case "::":
		return "::1"
	}
	return host
}
// removeDomainsFromHosts deletes the bridge-managed section (whatever
// hostsModRE matches) from the system hosts file.
//
// If the file contains no such section, it is left untouched. Otherwise
// trailing whitespace is trimmed and, unless the file is now empty, a single
// line ending is re-appended. Failures are logged as warnings and returned.
func removeDomainsFromHosts(ctx *api.ApiContext) error {
	ctx.Logger.Info("removing domains from hosts")
	path, err := getHostsPath()
	if err != nil {
		ctx.Logger.WithError(err).Warn("unable to get hosts path")
		return err
	}
	hosts, err := os.ReadFile(path)
	if err != nil {
		ctx.Logger.WithError(err).Warn("unable to read hosts file")
		return err
	}
	if !hostsModRE.Match(hosts) {
		return nil // nothing of ours in the file; do not rewrite it
	}
	hosts = hostsModRE.ReplaceAll(hosts, []byte{})
	hosts = bytes.TrimRight(hosts, "\r\n\t ")
	if len(hosts) > 0 {
		// Keep the file newline-terminated so a later append such as
		// `echo "ip name" >> hosts` starts on its own line instead of
		// merging into the last entry. Reuse the file's line-ending style.
		eol := []byte("\n")
		if bytes.Contains(hosts, []byte("\r\n")) {
			eol = []byte("\r\n")
		}
		hosts = append(hosts, eol...)
	}
	err = os.WriteFile(path, hosts, 0644)
	if err != nil {
		ctx.Logger.WithError(err).Warn("unable to write hosts file")
	}
	return err
}
// main runs the bridge. It:
//   - loads the config (creating it with defaults if loading fails);
//   - makes sure a CA certificate/key pair exists, issuing one if not;
//   - issues an ephemeral TLS certificate for the API domains;
//   - points those domains at the bridge via the hosts file;
//   - serves the API on cfg.Host (HTTP) and cfg.TlsHost (HTTPS).
//
// It runs until either listener fails or an interrupt arrives. The temporary
// certificate files and hosts entries are cleaned up by deferred calls on
// normal return.
func main() {
	// Must run before *configPath is read; otherwise -config is silently
	// ignored and the default path is always used.
	flag.Parse()

	logger := logrus.New()
	logger.Info("starting")
	cfg := config.Default()
	cfg.Path = *configPath
	if err := cfg.Load(); err != nil {
		logger.Info("failed to load config: ", err.Error(), ". attempting to create new")
		if err := cfg.Save(); err != nil {
			logger.Panic("failed to create config: ", err.Error())
		}
	}
	if cfg.Debug {
		logger.SetLevel(logrus.DebugLevel)
	} else {
		logger.SetLevel(logrus.InfoLevel)
	}
	_, certReadErr := os.Stat(cfg.TlsCertPath)
	_, keyReadErr := os.Stat(cfg.TlsKeyPath)
	if certReadErr != nil || keyReadErr != nil {
		logger.Info("CA certificate or key not found, issuing new")
		err := issueCACert(cfg.TlsCertPath, cfg.TlsKeyPath)
		if err != nil {
			logger.Panic("failed to issue CA certificate: ", err.Error())
		}
	}
	if cfg.Token == "" {
		logger.Fatal("api token not specified, get it at https://sati.ac/dashboard")
	}
	if strings.HasPrefix(cfg.Host, "0.0.0.0:") || strings.HasPrefix(cfg.TlsHost, "0.0.0.0:") {
		logger.Warn("you are trying to listen on all interfaces, THIS IS INSECURE")
	}
	satiConfig := sati.NewConfig(cfg.Token)
	satiConfig.Debug = cfg.Debug
	satiApi := sati.NewApi(satiConfig)
	registry := api.NewTaskRegistry(satiApi, cfg)
	ctx := api.ApiContext{
		Config:   cfg,
		Api:      satiApi,
		Registry: registry,
		Logger:   logger,
	}
	server := api.NewApiServer(&ctx)
	ctx.Logger.WithFields(logrus.Fields{
		"domains": server.GetDomains(),
	}).Debug("api server created")
	logger.Info("issuing ephemeral certificate")
	certFile, keyFile, err := issueCert(server.GetDomains(), cfg.TlsCertPath, cfg.TlsKeyPath)
	if err != nil {
		logger.Panic(err)
	}
	defer os.Remove(certFile)
	defer os.Remove(keyFile)
	if addDomainsToHosts(&ctx) == nil {
		defer removeDomainsFromHosts(&ctx)
	}
	logger.Info("starting api server")

	// One slot per sender (two listeners plus the signal watcher) so the
	// senders that lose the race never block on a send nobody receives.
	terminator := make(chan error, 3)
	go func() { terminator <- http.ListenAndServe(cfg.Host, server) }()
	go func() { terminator <- http.ListenAndServeTLS(cfg.TlsHost, certFile, keyFile, server) }()
	c := make(chan os.Signal, 1)
	// NOTE(review): only os.Interrupt is caught. A SIGTERM (e.g. from a
	// service manager) terminates the process without running the deferred
	// hosts/temp-file cleanup; consider also handling syscall.SIGTERM.
	signal.Notify(c, os.Interrupt)
	go func() {
		<-c
		terminator <- fmt.Errorf("interrupted")
	}()
	logger.Error(<-terminator)
}