package main

import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"io"
	"net"
	"os"
	"runtime"
	"sort"
	"strings"
	"time"

	"github.com/golang/glog"
	"github.com/joho/godotenv"
	"github.com/openshift/machine-config-operator/pkg/version"
	"github.com/spf13/cobra"
	"k8s.io/apimachinery/pkg/util/wait"
)

var (
	// runCmd is the "run" subcommand. It is attached to rootCmd in init and
	// executes runRunCmd when selected.
	runCmd = &cobra.Command{
		Use:   "run",
		Short: "Runs the setup-etcd-environment",
		Long:  "",
		RunE:  runRunCmd,
	}

	// runOpts holds the option values read by runRunCmd. discoverySRV,
	// outputFile and bootstrapSRV are bound to the --discovery-srv,
	// --output-file and --bootstrap-srv flags in init. ifName is not bound to
	// a flag or read by any code visible in this file.
	runOpts struct {
		discoverySRV string
		ifName       string
		outputFile   string
		bootstrapSRV bool
	}
)

// init wires the "run" subcommand into rootCmd and declares the persistent
// flags whose values are stored in runOpts.
func init() {
	rootCmd.AddCommand(runCmd)

	// All three options live on the same persistent flag set.
	flags := rootCmd.PersistentFlags()
	flags.StringVar(
		&runOpts.discoverySRV,
		"discovery-srv",
		"",
		"DNS domain used to populate envs from SRV query.",
	)
	flags.StringVar(
		&runOpts.outputFile,
		"output-file",
		"",
		"file where the envs are written. If empty, prints to Stdout.",
	)
	flags.BoolVar(
		&runOpts.bootstrapSRV,
		"bootstrap-srv",
		true,
		"use SRV discovery for bootstraping etcd cluster.",
	)
}

// runRunCmd implements the "run" subcommand. It works out which etcd SRV
// target under --discovery-srv resolves to one of this host's global unicast
// addresses, then writes the resulting ETCD_* variables to --output-file, or
// to stdout when no output file is configured.
//
// Two groups of variables are written, in this order:
//   - bootstrap variables (from setBootstrapEnv, plus UNSUPPORTED_ARCH on
//     s390/s390x), each prefixed with "export";
//   - address and DNS variables for the discovered member, without "export".
//
// It returns an error when --discovery-srv is empty, when no SRV target
// matches a local address before the poll times out, when the selected
// address cannot be parsed, or when reading, writing or closing the
// environment file fails.
func runRunCmd(cmd *cobra.Command, args []string) (err error) {
	// glog is configured through the standard flag package: send its output
	// to stderr and let the flag package parse the command line.
	flag.Set("logtostderr", "true")
	flag.Parse()

	// To help debugging, immediately log version
	glog.Infof("Version: %+v (%s)", version.Raw, version.Hash)

	if runOpts.discoverySRV == "" {
		return errors.New("--discovery-srv cannot be empty")
	}

	ips, err := ipAddrs()
	if err != nil {
		return err
	}

	// Retry every 30s for up to 5 minutes: per-candidate lookup failures are
	// only logged, so the poll keeps going until some address matches.
	var dns string
	var ip string
	if err := wait.PollImmediate(30*time.Second, 5*time.Minute, func() (bool, error) {
		for _, cand := range ips {
			found, err := reverseLookupSelf("etcd-server-ssl", "tcp", runOpts.discoverySRV, cand)
			if err != nil {
				glog.Errorf("error looking up self for candidate IP %s: %v", cand, err)
				continue
			}
			if found != "" {
				dns = found
				ip = cand
				return true, nil
			}
			glog.V(4).Infof("no matching dns for %s", cand)
		}
		return false, nil
	}); err != nil {
		return fmt.Errorf("could not find self: %v", err)
	}
	glog.Infof("dns name is %s", dns)

	// Validate the selected address before the output file is opened, so a
	// failure here cannot leave a truncated or half-written environment file.
	parsedIP := net.ParseIP(ip)
	if parsedIP == nil {
		return fmt.Errorf("Failed to parse IP '%s'", ip)
	}

	// initialize envs used to bootstrap etcd. This reads the existing output
	// file, so it must run before os.Create truncates that file below.
	exportEnv, err := setBootstrapEnv(runOpts.outputFile, runOpts.discoverySRV, runOpts.bootstrapSRV)
	if err != nil {
		return err
	}

	out := os.Stdout
	if runOpts.outputFile != "" {
		f, createErr := os.Create(runOpts.outputFile)
		if createErr != nil {
			return createErr
		}
		// Report a failed Close unless an earlier error is already being
		// returned; the file was just written, so a close failure means the
		// environment file may be incomplete.
		defer func() {
			if closeErr := f.Close(); closeErr != nil && err == nil {
				err = closeErr
			}
		}()
		out = f
	}

	// enable etcd to run using s390 and s390x. Because these are not officially supported upstream
	// etcd requires population of environment variable ETCD_UNSUPPORTED_ARCH at runtime.
	// https://github.com/etcd-io/etcd/blob/master/Documentation/op-guide/supported-platform.md
	arch := runtime.GOARCH
	if strings.HasPrefix(arch, "s390") {
		exportEnv["UNSUPPORTED_ARCH"] = arch
	}
	if err := writeEnvironmentFile(exportEnv, out, true); err != nil {
		return err
	}

	escapedIP := ip
	escapedAllIPs := "0.0.0.0"
	localhostIP := "127.0.0.1"
	escapedLocalhostIP := "127.0.0.1"
	if parsedIP.To4() == nil {
		// This is an IPv6 address, not IPv4.

		// When using an IPv6 address in a URL, we must wrap the address portion in
		// [::] so that a ":port" suffix can still be added and parsed correctly.
		escapedIP = fmt.Sprintf("[%s]", ip)
		escapedAllIPs = "[::]"
		localhostIP = "::1"
		escapedLocalhostIP = "[::1]"
	}

	return writeEnvironmentFile(map[string]string{
		// TODO This can actually be IPv6, so we should rename this ...
		"IPV4_ADDRESS":         ip,
		"ESCAPED_IP_ADDRESS":   escapedIP,
		"ESCAPED_ALL_IPS":      escapedAllIPs,
		"LOCALHOST_IP":         localhostIP,
		"ESCAPED_LOCALHOST_IP": escapedLocalhostIP,
		"DNS_NAME":             dns,
		"WILDCARD_DNS_NAME":    fmt.Sprintf("*.%s", runOpts.discoverySRV),
	}, out, false)
}

// setBootstrapEnv returns the variables used to bootstrap etcd, keyed without
// the "ETCD_" prefix.
//
// Unless os.Stat reports that envFile does not exist, the file is parsed and
// any non-empty ETCD_INITIAL_CLUSTER / ETCD_INITIAL_CLUSTER_STATE values are
// carried over as recovery settings; a parse failure is returned as the error.
// DISCOVERY_SRV is set to discoverySRV only when no recovery value was carried
// over and bootstrapSRV is true.
func setBootstrapEnv(envFile, discoverySRV string, bootstrapSRV bool) (map[string]string, error) {
	env := map[string]string{}

	_, statErr := os.Stat(envFile)
	if !os.IsNotExist(statErr) {
		previous, err := godotenv.Read(envFile)
		if err != nil {
			return nil, err
		}
		// Keep whatever recovery settings the earlier file recorded.
		recoveryKeys := [...]string{"INITIAL_CLUSTER", "INITIAL_CLUSTER_STATE"}
		for _, key := range recoveryKeys {
			if value := previous[fmt.Sprintf("ETCD_%s", key)]; value != "" {
				env[key] = value
			}
		}
	}

	// SRV discovery applies only when enabled and not in recovery mode.
	if bootstrapSRV && len(env) == 0 {
		env["DISCOVERY_SRV"] = discoverySRV
	}
	return env, nil
}

// ipAddrs lists, in textual form, every global unicast IP address reported by
// net.InterfaceAddrs. Addresses of other kinds and entries that carry no IP
// are skipped. On failure it returns a nil slice together with the error.
func ipAddrs() ([]string, error) {
	ifaceAddrs, err := net.InterfaceAddrs()
	if err != nil {
		return nil, err
	}

	var unicast []string
	for _, a := range ifaceAddrs {
		// Pull the IP out of whichever concrete address type was returned.
		var candidate net.IP
		if ipNet, ok := a.(*net.IPNet); ok {
			candidate = ipNet.IP
		} else if ipAddr, ok := a.(*net.IPAddr); ok {
			candidate = ipAddr.IP
		}

		// we only want global unicast address
		if candidate == nil || !candidate.IsGlobalUnicast() {
			continue
		}
		unicast = append(unicast, candidate.String())
	}

	return unicast, nil
}

// reverseLookupSelf returns the target of the SRV record for
// _service._proto.name whose host resolves to the address self, with leading
// and trailing dots trimmed from the target.
//
// Every SRV target is resolved in turn. If more than one target resolves to
// self, the last one wins. An error is returned when the SRV query fails, when
// any target cannot be resolved (the resolver's error is included in the
// message), or when no target resolves to self.
func reverseLookupSelf(service, proto, name, self string) (string, error) {
	_, srvs, err := net.LookupSRV(service, proto, name)
	if err != nil {
		return "", err
	}
	selfTarget := ""
	for _, srv := range srvs {
		glog.V(4).Infof("checking against %s", srv.Target)
		addrs, err := net.LookupHost(srv.Target)
		if err != nil {
			// Keep the underlying cause so the caller's log shows why the
			// member could not be resolved.
			return "", fmt.Errorf("could not resolve member %q: %v", srv.Target, err)
		}

		for _, addr := range addrs {
			if addr == self {
				selfTarget = strings.Trim(srv.Target, ".")
				break
			}
		}
	}
	if selfTarget == "" {
		return "", fmt.Errorf("could not find self")
	}
	return selfTarget, nil
}

// writeEnvironmentFile writes one "ETCD_<key>=<value>" line per entry of m to
// w, prefixing each line with "export " when export is true. Values are
// written verbatim, with no quoting or escaping.
//
// Lines are emitted in ascending key order so that the output is identical
// from run to run; Go's map iteration order is otherwise randomized. All lines
// are assembled in memory first and handed to w together. The error from
// writing to w, if any, is returned.
func writeEnvironmentFile(m map[string]string, w io.Writer, export bool) error {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var buffer bytes.Buffer
	for _, k := range keys {
		if export {
			buffer.WriteString("export ")
		}
		fmt.Fprintf(&buffer, "ETCD_%s=%s\n", k, m[k])
	}

	_, err := buffer.WriteTo(w)
	return err
}
