From 6d40de2179a644a722ee314554ecac5df9f7fdda Mon Sep 17 00:00:00 2001 From: Skyxim Date: Mon, 27 Mar 2023 22:27:59 +0800 Subject: [PATCH] chore: adjust trust cert --- component/tls/config.go | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/component/tls/config.go b/component/tls/config.go index f0155d782..91b89f1dc 100644 --- a/component/tls/config.go +++ b/component/tls/config.go @@ -14,7 +14,7 @@ import ( xtls "github.com/xtls/go" ) -var trustCert, _ = x509.SystemCertPool() +var trustCerts []*x509.Certificate var mutex sync.RWMutex var errNotMacth error = errors.New("certificate fingerprints do not match") @@ -25,16 +25,28 @@ func AddCertificate(certificate string) error { if certificate == "" { return fmt.Errorf("certificate is empty") } - if ok := trustCert.AppendCertsFromPEM([]byte(certificate)); !ok { + if cert, err := x509.ParseCertificate([]byte(certificate)); err == nil { + trustCerts = append(trustCerts, cert) + return nil + } else { return fmt.Errorf("add certificate failed") } - return nil } func ResetCertificate() { mutex.Lock() defer mutex.Unlock() - trustCert, _ = x509.SystemCertPool() + trustCerts = nil +} + +func getCertPool() *x509.CertPool { + certPool, err := x509.SystemCertPool() + if err == nil { + for _, cert := range trustCerts { + certPool.AddCert(cert) + } + } + return certPool } func verifyFingerprint(fingerprint *[32]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { @@ -84,12 +96,13 @@ func GetSpecifiedFingerprintTLSConfig(tlsConfig *tls.Config, fingerprint string) } func GetGlobalTLSConfig(tlsConfig *tls.Config) *tls.Config { + certPool := getCertPool() if tlsConfig == nil { return &tls.Config{ - RootCAs: trustCert, + RootCAs: certPool, } } - tlsConfig.RootCAs = trustCert + tlsConfig.RootCAs = certPool return tlsConfig } @@ -106,12 +119,13 @@ func GetSpecifiedFingerprintXTLSConfig(tlsConfig *xtls.Config, fingerprint strin } func GetGlobalXTLSConfig(tlsConfig *xtls.Config) *xtls.Config { + certPool := getCertPool() if tlsConfig == nil { return &xtls.Config{ - RootCAs: trustCert, + RootCAs: certPool, } } - tlsConfig.RootCAs = trustCert + tlsConfig.RootCAs = certPool return tlsConfig }