This commit is contained in:
332
main.go
Normal file
332
main.go
Normal file
@ -0,0 +1,332 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.sati.ac/sati.ac/bridge/api"
|
||||
"git.sati.ac/sati.ac/bridge/config"
|
||||
"git.sati.ac/sati.ac/sati-go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type bridge struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func issueCACert(certPath string, keyPath string) error {
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(0),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "Bridge CA",
|
||||
Country: []string{"VA"},
|
||||
Organization: []string{"sati.ac"},
|
||||
Locality: []string{"Everywhere"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(100, 0, 0),
|
||||
IsCA: true,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
keys, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, cert, cert, &keys.PublicKey, keys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certFile, err := os.OpenFile(certPath, os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer certFile.Close()
|
||||
|
||||
if err := pem.Encode(certFile, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: der,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyFile, err := os.OpenFile(keyPath, os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer keyFile.Close()
|
||||
|
||||
if err := pem.Encode(keyFile, &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(keys),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func issueCert(domains []string, caCertPath string, caKeyPath string) (string, string, error) {
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(time.Now().Unix()),
|
||||
Subject: pkix.Name{
|
||||
Country: []string{"VA"},
|
||||
Organization: []string{"sati.ac"},
|
||||
OrganizationalUnit: []string{"Bridge ephemeral certificate"},
|
||||
Locality: []string{"Everywhere"},
|
||||
},
|
||||
DNSNames: domains,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(100, 0, 0),
|
||||
SubjectKeyId: []byte{0},
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
}
|
||||
|
||||
keys, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
caCertPem, err := os.ReadFile(caCertPath)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
block, _ := pem.Decode(caCertPem)
|
||||
if block == nil || block.Type != "CERTIFICATE" {
|
||||
return "", "", fmt.Errorf(`certificate: bad pem block "%s"`, block.Type)
|
||||
}
|
||||
caCert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
caKeyPem, err := os.ReadFile(caKeyPath)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
block, _ = pem.Decode(caKeyPem)
|
||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||
return "", "", fmt.Errorf(`key: bad pem block "%s"`, block.Type)
|
||||
}
|
||||
caKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, cert, caCert, &keys.PublicKey, caKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
certFile, err := os.CreateTemp("", "bridge*.crt")
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer certFile.Close()
|
||||
keyFile, err := os.CreateTemp("", "bridge*.key")
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer keyFile.Close()
|
||||
|
||||
if err := pem.Encode(certFile, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: der,
|
||||
}); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if err := pem.Encode(keyFile, &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(keys),
|
||||
}); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return certFile.Name(), keyFile.Name(), nil
|
||||
}
|
||||
|
||||
func getHostsPath() (string, error) {
|
||||
switch runtime.GOOS {
|
||||
case "freebsd":
|
||||
fallthrough
|
||||
case "openbsd":
|
||||
fallthrough
|
||||
case "dragonfly":
|
||||
fallthrough
|
||||
case "netbsd":
|
||||
fallthrough
|
||||
case "darwin":
|
||||
fallthrough
|
||||
case "android":
|
||||
fallthrough
|
||||
case "linux":
|
||||
return "/etc/hosts", nil
|
||||
case "windows":
|
||||
return `C:\Windows\System32\drivers\etc\hosts`, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unknown os: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
var hostsModRE = regexp.MustCompile("(?:\n|\r\n|\r)#sati-bridge start, DO NOT MODIFY(?:\n|\r\n|\r)[^#]*(?:\n|\r\n|\r)#sati-bridge end")
|
||||
|
||||
var configPath = flag.String("config", "./data/config.json", "config path")
|
||||
|
||||
func addDomainsToHosts(ctx *api.ApiContext) error {
|
||||
ctx.Logger.Info("adding domains to hosts")
|
||||
path, err := getHostsPath()
|
||||
if err != nil {
|
||||
ctx.Logger.WithError(err).Warn("unable to get hosts path")
|
||||
return err
|
||||
}
|
||||
|
||||
hosts, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
ctx.Logger.WithError(err).Warn("unable to read hosts file")
|
||||
return err
|
||||
}
|
||||
|
||||
hosts = hostsModRE.ReplaceAll(hosts, []byte{}) // remove old entries
|
||||
|
||||
hostIp := "127.0.0.1"
|
||||
suffix := "\r\n#sati-bridge start, DO NOT MODIFY\r\n"
|
||||
for _, domain := range ctx.Server.GetDomains() {
|
||||
suffix += hostIp + " " + domain + "\r\n"
|
||||
}
|
||||
suffix += "#sati-bridge end\r\n"
|
||||
hosts = []byte(strings.TrimRight(string(hosts), "\r\n\t ") + suffix)
|
||||
|
||||
err = os.WriteFile(path, hosts, 0644)
|
||||
if err != nil {
|
||||
ctx.Logger.WithError(err).Warn("unable to write hosts file")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func removeDomainsFromHosts(ctx *api.ApiContext) error {
|
||||
ctx.Logger.Info("removing domains from hosts")
|
||||
path, err := getHostsPath()
|
||||
if err != nil {
|
||||
ctx.Logger.WithError(err).Warn("unable to get hosts path")
|
||||
return err
|
||||
}
|
||||
|
||||
hosts, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
ctx.Logger.WithError(err).Warn("unable to read hosts file")
|
||||
return err
|
||||
}
|
||||
|
||||
hosts = hostsModRE.ReplaceAll(hosts, []byte{})
|
||||
hosts = bytes.TrimRight(hosts, "\r\n\t ")
|
||||
|
||||
err = os.WriteFile(path, hosts, 0644)
|
||||
if err != nil {
|
||||
ctx.Logger.WithError(err).Warn("unable to write hosts file")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func main() {
|
||||
logger := logrus.New()
|
||||
logger.Info("starting")
|
||||
|
||||
cfg := config.Default()
|
||||
cfg.Path = *configPath
|
||||
if err := cfg.Load(); err != nil {
|
||||
logger.Info("failed to load config: ", err.Error(), ". attempting to create new")
|
||||
if err := cfg.Save(); err != nil {
|
||||
logger.Panic("failed to create config: ", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Debug {
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
} else {
|
||||
logger.SetLevel(logrus.InfoLevel)
|
||||
}
|
||||
|
||||
_, certReadErr := os.Stat(cfg.TlsCertPath)
|
||||
_, keyReadErr := os.Stat(cfg.TlsKeyPath)
|
||||
if certReadErr != nil || keyReadErr != nil {
|
||||
logger.Info("CA certificate or key not found, issuing new")
|
||||
err := issueCACert(cfg.TlsCertPath, cfg.TlsKeyPath)
|
||||
if err != nil {
|
||||
logger.Panic("failed to issue CA certificate: ", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Token == "" {
|
||||
logger.Fatal("api token not specified, get it at https://sati.ac/dashboard")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(cfg.Host, "0.0.0.0:") || strings.HasPrefix(cfg.TlsHost, "0.0.0.0:") {
|
||||
logger.Warn("you are trying to listen on all interfaces, THIS IS INSECURE")
|
||||
}
|
||||
|
||||
satiConfig := sati.NewConfig(cfg.Token)
|
||||
satiConfig.Debug = cfg.Debug
|
||||
satiApi := sati.NewApi(satiConfig)
|
||||
|
||||
registry := api.NewTaskRegistry(satiApi, time.Minute)
|
||||
|
||||
ctx := api.ApiContext{
|
||||
Config: cfg,
|
||||
Api: satiApi,
|
||||
Registry: registry,
|
||||
Logger: logger,
|
||||
}
|
||||
|
||||
server := api.NewApiServer(&ctx)
|
||||
ctx.Logger.WithFields(logrus.Fields{
|
||||
"domains": server.GetDomains(),
|
||||
}).Debug("api server created")
|
||||
|
||||
logger.Info("issuing ephemeral certificate")
|
||||
certFile, keyFile, err := issueCert(server.GetDomains(), cfg.TlsCertPath, cfg.TlsKeyPath)
|
||||
if err != nil {
|
||||
logger.Panic(err)
|
||||
}
|
||||
defer os.Remove(certFile)
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
if addDomainsToHosts(&ctx) == nil {
|
||||
defer removeDomainsFromHosts(&ctx)
|
||||
}
|
||||
|
||||
logger.Info("starting api server")
|
||||
|
||||
terminator := make(chan error)
|
||||
go func() { terminator <- http.ListenAndServe(cfg.Host, server) }()
|
||||
go func() { terminator <- http.ListenAndServeTLS(cfg.TlsHost, certFile, keyFile, server) }()
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt)
|
||||
go func() {
|
||||
<-c
|
||||
terminator <- fmt.Errorf("interrupted")
|
||||
}()
|
||||
|
||||
logger.Error(<-terminator)
|
||||
}
|
Reference in New Issue
Block a user