333 lines
7.9 KiB
Go
333 lines
7.9 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"flag"
|
|
"fmt"
|
|
"math/big"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"regexp"
|
|
"runtime"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.sati.ac/sati.ac/bridge/api"
|
|
"git.sati.ac/sati.ac/bridge/config"
|
|
"git.sati.ac/sati.ac/sati-go"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
type bridge struct {
|
|
config *config.Config
|
|
}
|
|
|
|
// issueCACert creates a new self-signed CA certificate ("Bridge CA", valid for
// 100 years) backed by a freshly generated RSA-2048 key. It stores the
// PEM-encoded certificate at certPath and the PEM-encoded PKCS#1 private key
// at keyPath. Existing files are truncated before being overwritten; newly
// created files get mode 0600.
func issueCACert(certPath string, keyPath string) error {
	// RFC 5280 §4.1.2.2: the serial must be a positive integer of at most 20
	// octets. A random 127-bit value plus one is always in range. Unlike a
	// constant, it does not collide with CAs generated by other installations.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 127))
	if err != nil {
		return fmt.Errorf("generating CA serial number: %w", err)
	}
	serial.Add(serial, big.NewInt(1))

	now := time.Now()
	template := &x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			CommonName:   "Bridge CA",
			Country:      []string{"VA"},
			Organization: []string{"sati.ac"},
			Locality:     []string{"Everywhere"},
		},
		NotBefore:             now,
		NotAfter:              now.AddDate(100, 0, 0),
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return fmt.Errorf("generating CA key: %w", err)
	}

	// Self-signed: the template acts as both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	if err != nil {
		return fmt.Errorf("creating CA certificate: %w", err)
	}

	if err := writePEMFile(certPath, "CERTIFICATE", der); err != nil {
		return fmt.Errorf("writing CA certificate %q: %w", certPath, err)
	}
	if err := writePEMFile(keyPath, "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(key)); err != nil {
		return fmt.Errorf("writing CA key %q: %w", keyPath, err)
	}
	return nil
}

// writePEMFile writes a single PEM block of the given type to path. If the
// file is missing, it is created with mode 0600. If it exists, it is truncated
// first, so no bytes from a previous, longer file survive. The Close error is
// returned because some filesystems only report write failures on close.
func writePEMFile(path, blockType string, der []byte) error {
	f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	if err := pem.Encode(f, &pem.Block{Type: blockType, Bytes: der}); err != nil {
		f.Close() // the encode error is the one worth reporting
		return err
	}
	return f.Close()
}
|
|
|
|
// issueCert issues a leaf certificate for domains, signed by the CA whose
// PEM-encoded certificate and PKCS#1 RSA private key are read from caCertPath
// and caKeyPath. A fresh RSA-2048 key is generated for the leaf. The
// PEM-encoded certificate and key are written to two new files in the OS
// temporary directory, and their paths are returned (certificate first). The
// caller is responsible for removing them. On error no temporary files are
// left behind.
func issueCert(domains []string, caCertPath string, caKeyPath string) (string, string, error) {
	// The certificate is regenerated on every start, so a short validity is
	// enough. Apple platforms also reject TLS server certificates valid for
	// more than 825 days, regardless of which root they chain to.
	const validityDays = 365

	caCert, caKey, err := loadIssuingCA(caCertPath, caKeyPath)
	if err != nil {
		return "", "", err
	}

	// Random positive serial (RFC 5280 §4.1.2.2). Unlike a timestamp, it
	// cannot repeat across quick restarts. Firefox rejects a certificate whose
	// issuer and serial match a previously seen, different certificate.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 127))
	if err != nil {
		return "", "", err
	}
	serial.Add(serial, big.NewInt(1))

	now := time.Now()
	template := &x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Country:            []string{"VA"},
			Organization:       []string{"sati.ac"},
			OrganizationalUnit: []string{"Bridge ephemeral certificate"},
			Locality:           []string{"Everywhere"},
		},
		DNSNames:     domains,
		NotBefore:    now,
		NotAfter:     now.AddDate(0, 0, validityDays),
		SubjectKeyId: []byte{0},
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:     x509.KeyUsageDigitalSignature,
	}

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return "", "", err
	}

	der, err := x509.CreateCertificate(rand.Reader, template, caCert, &key.PublicKey, caKey)
	if err != nil {
		return "", "", err
	}

	certPath, err := writeTempPEM("bridge*.crt", "CERTIFICATE", der)
	if err != nil {
		return "", "", err
	}
	keyPath, err := writeTempPEM("bridge*.key", "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(key))
	if err != nil {
		os.Remove(certPath) // don't leak the half of the pair that was written
		return "", "", err
	}
	return certPath, keyPath, nil
}

// loadIssuingCA reads and parses the PEM "CERTIFICATE" at certPath and the PEM
// "RSA PRIVATE KEY" (PKCS#1) at keyPath. Files without a PEM block, or with a
// block of the wrong type, produce an error rather than a panic.
func loadIssuingCA(certPath, keyPath string) (*x509.Certificate, *rsa.PrivateKey, error) {
	certPEM, err := os.ReadFile(certPath)
	if err != nil {
		return nil, nil, err
	}
	block, _ := pem.Decode(certPEM)
	if block == nil {
		return nil, nil, fmt.Errorf("certificate: no pem block in %q", certPath)
	}
	if block.Type != "CERTIFICATE" {
		return nil, nil, fmt.Errorf(`certificate: bad pem block "%s"`, block.Type)
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, nil, err
	}

	keyPEM, err := os.ReadFile(keyPath)
	if err != nil {
		return nil, nil, err
	}
	block, _ = pem.Decode(keyPEM)
	if block == nil {
		return nil, nil, fmt.Errorf("key: no pem block in %q", keyPath)
	}
	if block.Type != "RSA PRIVATE KEY" {
		return nil, nil, fmt.Errorf(`key: bad pem block "%s"`, block.Type)
	}
	key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err != nil {
		return nil, nil, err
	}
	return cert, key, nil
}

// writeTempPEM writes one PEM block to a new file in the OS temp directory
// (named per os.CreateTemp pattern) and returns its path. The file is removed
// again if encoding or closing fails.
func writeTempPEM(pattern, blockType string, der []byte) (string, error) {
	f, err := os.CreateTemp("", pattern)
	if err != nil {
		return "", err
	}
	name := f.Name()
	err = pem.Encode(f, &pem.Block{Type: blockType, Bytes: der})
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		os.Remove(name)
		return "", err
	}
	return name, nil
}
|
|
|
|
// getHostsPath returns the location of the system hosts file for the current
// operating system. It returns an error for operating systems it does not
// know about. On Windows the path is derived from %SystemRoot%, so
// installations on a drive other than C: are handled. C:\Windows is assumed
// when the variable is unset.
func getHostsPath() (string, error) {
	return hostsPathFor(runtime.GOOS, os.Getenv("SystemRoot"))
}

// hostsPathFor maps an operating system name (as reported by runtime.GOOS)
// to its hosts file path. For Windows it also uses the SystemRoot directory.
func hostsPathFor(goos, systemRoot string) (string, error) {
	switch goos {
	case "freebsd", "openbsd", "dragonfly", "netbsd", "darwin", "android", "linux":
		return "/etc/hosts", nil
	case "windows":
		root := strings.TrimRight(systemRoot, `\/`)
		if root == "" {
			root = `C:\Windows`
		}
		return root + `\System32\drivers\etc\hosts`, nil
	}
	return "", fmt.Errorf("unknown os: %s", goos)
}
|
|
|
|
// hostsModRE matches the marker-delimited block that the bridge appends to the
// hosts file, including the line break just before the start marker, so the
// block can be stripped again. Line breaks may be CRLF, LF or CR. The block
// may hold zero or more entry lines, which must not contain '#'. An empty
// block matches too, whatever the line endings.
var hostsModRE = regexp.MustCompile(`(?:\r\n|\r|\n)#sati-bridge start, DO NOT MODIFY(?:\r\n|\r|\n)(?:[^#]*(?:\r\n|\r|\n))?#sati-bridge end`)

// configPath holds the configuration file location, registered as the
// -config command-line flag; the flag only takes effect after flag.Parse.
var configPath = flag.String("config", "./data/config.json", "config path")
|
|
|
|
func addDomainsToHosts(ctx *api.ApiContext) error {
|
|
ctx.Logger.Info("adding domains to hosts")
|
|
path, err := getHostsPath()
|
|
if err != nil {
|
|
ctx.Logger.WithError(err).Warn("unable to get hosts path")
|
|
return err
|
|
}
|
|
|
|
hosts, err := os.ReadFile(path)
|
|
if err != nil {
|
|
ctx.Logger.WithError(err).Warn("unable to read hosts file")
|
|
return err
|
|
}
|
|
|
|
hosts = hostsModRE.ReplaceAll(hosts, []byte{}) // remove old entries
|
|
|
|
hostIp := strings.SplitN(ctx.Config.Host, ":", 2)[0]
|
|
suffix := "\r\n#sati-bridge start, DO NOT MODIFY\r\n"
|
|
for _, domain := range ctx.Server.GetDomains() {
|
|
suffix += hostIp + " " + domain + "\r\n"
|
|
}
|
|
suffix += "#sati-bridge end\r\n"
|
|
hosts = []byte(strings.TrimRight(string(hosts), "\r\n\t ") + suffix)
|
|
|
|
err = os.WriteFile(path, hosts, 0644)
|
|
if err != nil {
|
|
ctx.Logger.WithError(err).Warn("unable to write hosts file")
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func removeDomainsFromHosts(ctx *api.ApiContext) error {
|
|
ctx.Logger.Info("removing domains from hosts")
|
|
path, err := getHostsPath()
|
|
if err != nil {
|
|
ctx.Logger.WithError(err).Warn("unable to get hosts path")
|
|
return err
|
|
}
|
|
|
|
hosts, err := os.ReadFile(path)
|
|
if err != nil {
|
|
ctx.Logger.WithError(err).Warn("unable to read hosts file")
|
|
return err
|
|
}
|
|
|
|
hosts = hostsModRE.ReplaceAll(hosts, []byte{})
|
|
hosts = bytes.TrimRight(hosts, "\r\n\t ")
|
|
|
|
err = os.WriteFile(path, hosts, 0644)
|
|
if err != nil {
|
|
ctx.Logger.WithError(err).Warn("unable to write hosts file")
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func main() {
|
|
logger := logrus.New()
|
|
logger.Info("starting")
|
|
|
|
cfg := config.Default()
|
|
cfg.Path = *configPath
|
|
if err := cfg.Load(); err != nil {
|
|
logger.Info("failed to load config: ", err.Error(), ". attempting to create new")
|
|
if err := cfg.Save(); err != nil {
|
|
logger.Panic("failed to create config: ", err.Error())
|
|
}
|
|
}
|
|
|
|
if cfg.Debug {
|
|
logger.SetLevel(logrus.DebugLevel)
|
|
} else {
|
|
logger.SetLevel(logrus.InfoLevel)
|
|
}
|
|
|
|
_, certReadErr := os.Stat(cfg.TlsCertPath)
|
|
_, keyReadErr := os.Stat(cfg.TlsKeyPath)
|
|
if certReadErr != nil || keyReadErr != nil {
|
|
logger.Info("CA certificate or key not found, issuing new")
|
|
err := issueCACert(cfg.TlsCertPath, cfg.TlsKeyPath)
|
|
if err != nil {
|
|
logger.Panic("failed to issue CA certificate: ", err.Error())
|
|
}
|
|
}
|
|
|
|
if cfg.Token == "" {
|
|
logger.Fatal("api token not specified, get it at https://sati.ac/dashboard")
|
|
}
|
|
|
|
if strings.HasPrefix(cfg.Host, "0.0.0.0:") || strings.HasPrefix(cfg.TlsHost, "0.0.0.0:") {
|
|
logger.Warn("you are trying to listen on all interfaces, THIS IS INSECURE")
|
|
}
|
|
|
|
satiConfig := sati.NewConfig(cfg.Token)
|
|
satiConfig.Debug = cfg.Debug
|
|
satiApi := sati.NewApi(satiConfig)
|
|
|
|
registry := api.NewTaskRegistry(satiApi, cfg)
|
|
|
|
ctx := api.ApiContext{
|
|
Config: cfg,
|
|
Api: satiApi,
|
|
Registry: registry,
|
|
Logger: logger,
|
|
}
|
|
|
|
server := api.NewApiServer(&ctx)
|
|
ctx.Logger.WithFields(logrus.Fields{
|
|
"domains": server.GetDomains(),
|
|
}).Debug("api server created")
|
|
|
|
logger.Info("issuing ephemeral certificate")
|
|
certFile, keyFile, err := issueCert(server.GetDomains(), cfg.TlsCertPath, cfg.TlsKeyPath)
|
|
if err != nil {
|
|
logger.Panic(err)
|
|
}
|
|
defer os.Remove(certFile)
|
|
defer os.Remove(keyFile)
|
|
|
|
if addDomainsToHosts(&ctx) == nil {
|
|
defer removeDomainsFromHosts(&ctx)
|
|
}
|
|
|
|
logger.Info("starting api server")
|
|
|
|
terminator := make(chan error)
|
|
go func() { terminator <- http.ListenAndServe(cfg.Host, server) }()
|
|
go func() { terminator <- http.ListenAndServeTLS(cfg.TlsHost, certFile, keyFile, server) }()
|
|
|
|
c := make(chan os.Signal, 1)
|
|
signal.Notify(c, os.Interrupt)
|
|
go func() {
|
|
<-c
|
|
terminator <- fmt.Errorf("interrupted")
|
|
}()
|
|
|
|
logger.Error(<-terminator)
|
|
}
|