package main import ( "bytes" "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "flag" "fmt" "math/big" "net/http" "os" "os/signal" "regexp" "runtime" "strings" "time" "git.sati.ac/sati.ac/bridge/api" "git.sati.ac/sati.ac/bridge/config" "git.sati.ac/sati.ac/sati-go" "github.com/sirupsen/logrus" ) type bridge struct { config *config.Config } func issueCACert(certPath string, keyPath string) error { cert := &x509.Certificate{ SerialNumber: big.NewInt(0), Subject: pkix.Name{ CommonName: "Bridge CA", Country: []string{"VA"}, Organization: []string{"sati.ac"}, Locality: []string{"Everywhere"}, }, NotBefore: time.Now(), NotAfter: time.Now().AddDate(100, 0, 0), IsCA: true, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } keys, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return err } der, err := x509.CreateCertificate(rand.Reader, cert, cert, &keys.PublicKey, keys) if err != nil { return err } certFile, err := os.OpenFile(certPath, os.O_CREATE|os.O_WRONLY, 0600) if err != nil { return err } defer certFile.Close() if err := pem.Encode(certFile, &pem.Block{ Type: "CERTIFICATE", Bytes: der, }); err != nil { return err } keyFile, err := os.OpenFile(keyPath, os.O_CREATE|os.O_WRONLY, 0600) if err != nil { return err } defer keyFile.Close() if err := pem.Encode(keyFile, &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(keys), }); err != nil { return err } return nil } func issueCert(domains []string, caCertPath string, caKeyPath string) (string, string, error) { cert := &x509.Certificate{ SerialNumber: big.NewInt(time.Now().Unix()), Subject: pkix.Name{ Country: []string{"VA"}, Organization: []string{"sati.ac"}, OrganizationalUnit: []string{"Bridge ephemeral certificate"}, Locality: []string{"Everywhere"}, }, DNSNames: domains, NotBefore: time.Now(), NotAfter: time.Now().AddDate(100, 0, 0), SubjectKeyId: []byte{0}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature, } keys, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return "", "", err } caCertPem, err := os.ReadFile(caCertPath) if err != nil { return "", "", err } block, _ := pem.Decode(caCertPem) if block == nil || block.Type != "CERTIFICATE" { return "", "", fmt.Errorf(`certificate: bad pem block "%s"`, block.Type) } caCert, err := x509.ParseCertificate(block.Bytes) if err != nil { return "", "", err } caKeyPem, err := os.ReadFile(caKeyPath) if err != nil { return "", "", err } block, _ = pem.Decode(caKeyPem) if block == nil || block.Type != "RSA PRIVATE KEY" { return "", "", fmt.Errorf(`key: bad pem block "%s"`, block.Type) } caKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return "", "", err } der, err := x509.CreateCertificate(rand.Reader, cert, caCert, &keys.PublicKey, caKey) if err != nil { return "", "", err } certFile, err := os.CreateTemp("", "bridge*.crt") if err != nil { return "", "", err } defer certFile.Close() keyFile, err := os.CreateTemp("", "bridge*.key") if err != nil { return "", "", err } defer keyFile.Close() if err := pem.Encode(certFile, &pem.Block{ Type: "CERTIFICATE", Bytes: der, }); err != nil { return "", "", err } if err := pem.Encode(keyFile, &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(keys), }); err != nil { return "", "", err } return certFile.Name(), keyFile.Name(), nil } func getHostsPath() (string, error) { switch runtime.GOOS { case "freebsd": fallthrough case "openbsd": fallthrough case "dragonfly": fallthrough case "netbsd": fallthrough case "darwin": fallthrough case "android": fallthrough case "linux": return "/etc/hosts", nil case "windows": return `C:\Windows\System32\drivers\etc\hosts`, nil } return "", fmt.Errorf("unknown os: %s", runtime.GOOS) } var hostsModRE = regexp.MustCompile("(?:\n|\r\n|\r)#sati-bridge start, DO NOT MODIFY(?:\n|\r\n|\r)[^#]*(?:\n|\r\n|\r)#sati-bridge end") var configPath = flag.String("config", "./data/config.json", "config path") func addDomainsToHosts(ctx *api.ApiContext) error { ctx.Logger.Info("adding domains to hosts") path, err := getHostsPath() if err != nil { ctx.Logger.WithError(err).Warn("unable to get hosts path") return err } hosts, err := os.ReadFile(path) if err != nil { ctx.Logger.WithError(err).Warn("unable to read hosts file") return err } hosts = hostsModRE.ReplaceAll(hosts, []byte{}) // remove old entries hostIp := strings.SplitN(ctx.Config.Host, ":", 2)[0] suffix := "\r\n#sati-bridge start, DO NOT MODIFY\r\n" for _, domain := range ctx.Server.GetDomains() { suffix += hostIp + " " + domain + "\r\n" } suffix += "#sati-bridge end\r\n" hosts = []byte(strings.TrimRight(string(hosts), "\r\n\t ") + suffix) err = os.WriteFile(path, hosts, 0644) if err != nil { ctx.Logger.WithError(err).Warn("unable to write hosts file") } return err } func removeDomainsFromHosts(ctx *api.ApiContext) error { ctx.Logger.Info("removing domains from hosts") path, err := getHostsPath() if err != nil { ctx.Logger.WithError(err).Warn("unable to get hosts path") return err } hosts, err := os.ReadFile(path) if err != nil { ctx.Logger.WithError(err).Warn("unable to read hosts file") return err } hosts = hostsModRE.ReplaceAll(hosts, []byte{}) hosts = bytes.TrimRight(hosts, "\r\n\t ") err = os.WriteFile(path, hosts, 0644) if err != nil { ctx.Logger.WithError(err).Warn("unable to write hosts file") } return err } func main() { logger := logrus.New() logger.Info("starting") cfg := config.Default() cfg.Path = *configPath if err := cfg.Load(); err != nil { logger.Info("failed to load config: ", err.Error(), ". attempting to create new") if err := cfg.Save(); err != nil { logger.Panic("failed to create config: ", err.Error()) } } if cfg.Debug { logger.SetLevel(logrus.DebugLevel) } else { logger.SetLevel(logrus.InfoLevel) } _, certReadErr := os.Stat(cfg.TlsCertPath) _, keyReadErr := os.Stat(cfg.TlsKeyPath) if certReadErr != nil || keyReadErr != nil { logger.Info("CA certificate or key not found, issuing new") err := issueCACert(cfg.TlsCertPath, cfg.TlsKeyPath) if err != nil { logger.Panic("failed to issue CA certificate: ", err.Error()) } } if cfg.Token == "" { logger.Fatal("api token not specified, get it at https://sati.ac/dashboard") } if strings.HasPrefix(cfg.Host, "0.0.0.0:") || strings.HasPrefix(cfg.TlsHost, "0.0.0.0:") { logger.Warn("you are trying to listen on all interfaces, THIS IS INSECURE") } satiConfig := sati.NewConfig(cfg.Token) satiConfig.Debug = cfg.Debug satiApi := sati.NewApi(satiConfig) registry := api.NewTaskRegistry(satiApi, cfg) ctx := api.ApiContext{ Config: cfg, Api: satiApi, Registry: registry, Logger: logger, } server := api.NewApiServer(&ctx) ctx.Logger.WithFields(logrus.Fields{ "domains": server.GetDomains(), }).Debug("api server created") logger.Info("issuing ephemeral certificate") certFile, keyFile, err := issueCert(server.GetDomains(), cfg.TlsCertPath, cfg.TlsKeyPath) if err != nil { logger.Panic(err) } defer os.Remove(certFile) defer os.Remove(keyFile) if addDomainsToHosts(&ctx) == nil { defer removeDomainsFromHosts(&ctx) } logger.Info("starting api server") terminator := make(chan error) go func() { terminator <- http.ListenAndServe(cfg.Host, server) }() go func() { terminator <- http.ListenAndServeTLS(cfg.TlsHost, certFile, keyFile, server) }() c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) go func() { <-c terminator <- fmt.Errorf("interrupted") }() logger.Error(<-terminator) }