diff --git a/connection.go b/connection.go index d8efdf5..6e64e2d 100644 --- a/connection.go +++ b/connection.go @@ -75,8 +75,12 @@ type SQCloud struct { ErrorMessage string } -const CompressModeNo = "NO" -const CompressModeLZ4 = "LZ4" +const SQLiteDefaultPort = 8860 + +const ( + CompressModeNo = "NO" + CompressModeLZ4 = "LZ4" +) const SQLiteCloudCA = "SQLiteCloudCA" @@ -108,7 +112,7 @@ func ParseConnectionString(ConnectionString string) (config *SQCloudConfig, err config = &SQCloudConfig{} config.Host = u.Hostname() - config.Port = 0 + config.Port = SQLiteDefaultPort config.Username = u.User.Username() config.Password, _ = u.User.Password() config.Database = strings.TrimPrefix(u.Path, "/") @@ -243,74 +247,10 @@ func (this *SQCloud) CheckConnectionParameter() error { return fmt.Errorf("Invalid hostname (%s)", this.Host) } - // ip := net.ParseIP(this.Host) - // if ip == nil { - // if _, err := net.LookupHost(this.Host); err != nil { - // return errors.New(fmt.Sprintf("Can't resolve hostname (%s)", this.Host)) - // } - // } - - if this.Port == 0 { - this.Port = 8860 - } - if this.Port < 1 || this.Port >= 0xFFFF { - return errors.New(fmt.Sprintf("Invalid Port (%d)", this.Port)) - } - - // if this.Timeout == 0 { - // this.Timeout = 10 * time.Second - // } if this.Timeout < 0 { return errors.New(fmt.Sprintf("Invalid Timeout (%s)", this.Timeout.String())) } - switch this.CompressMode { - case CompressModeNo, CompressModeLZ4: - default: - return errors.New(fmt.Sprintf("Invalid compression method (%s)", this.CompressMode)) - } - - if this.Secure { - var pool *x509.CertPool = nil - pem := []byte{} - - switch _, _, trimmed := ParseTlsString(this.Pem); trimmed { - case "": - break - case SQLiteCloudCA: - pem = []byte(sqliteCloudCAPEM) - default: - // check if it is a filepath - _, err := os.Stat(trimmed) - if os.IsNotExist(err) { - // not a filepath, use the string as a pem string - pem = []byte(trimmed) - } else { - // its a file, read its content into the pem string - switch bytes, err := os.ReadFile(trimmed); { - case err != nil: - return errors.New(fmt.Sprintf("Could not open PEM file in '%s'", trimmed)) - default: - pem = bytes - } - } - } - - if len(pem) > 0 { - pool = x509.NewCertPool() - - if !pool.AppendCertsFromPEM(pem) { - return errors.New(fmt.Sprintf("Could not append certs from PEM")) - } - } - - this.cert = &tls.Config{ - RootCAs: pool, - InsecureSkipVerify: this.TlsInsecureSkipVerify, - MinVersion: tls.VersionTLS12, - } - } - return nil } @@ -353,12 +293,11 @@ func Connect(ConnectionString string) (*SQCloud, error) { func (this *SQCloud) Connect() error { this.reset() // also closes an open connection - switch err := this.CheckConnectionParameter(); { - case err != nil: + if err := this.CheckConnectionParameter(); err != nil { return err - default: - return this.reconnect() } + + return this.reconnect() } // reconnect closes and then reopens a connection to the SQLite Cloud database server. @@ -369,6 +308,14 @@ func (this *SQCloud) reconnect() error { this.resetError() + if this.Secure { + cert, err := getTlsConfig(&this.SQCloudConfig) + if err != nil { + return err + } + this.cert = cert + } + var dialer = net.Dialer{} dialer.Timeout = this.Timeout dialer.DualStack = true @@ -437,6 +384,47 @@ func (this *SQCloud) Close() error { return nil } +func getTlsConfig(config *SQCloudConfig) (*tls.Config, error) { + var pool *x509.CertPool = nil + pem := []byte{} + + switch _, _, trimmed := ParseTlsString(config.Pem); trimmed { + case "": + break + case SQLiteCloudCA: + pem = []byte(sqliteCloudCAPEM) + default: + // check if it is a filepath + _, err := os.Stat(trimmed) + if os.IsNotExist(err) { + // not a filepath, use the string as a pem string + pem = []byte(trimmed) + } else { + // its a file, read its content into the pem string + switch bytes, err := os.ReadFile(trimmed); { + case err != nil: + return nil, fmt.Errorf("could not open PEM file in '%s'", trimmed) + default: + pem = bytes + } + } + } + + if len(pem) > 0 { + pool = x509.NewCertPool() + + if !pool.AppendCertsFromPEM(pem) { + return nil, fmt.Errorf("could not append certs from PEM") + } + } + + return &tls.Config{ + RootCAs: pool, + InsecureSkipVerify: config.TlsInsecureSkipVerify, + MinVersion: tls.VersionTLS12, + }, nil +} + func connectionCommands(config SQCloudConfig) (string, []interface{}) { buffer := "" args := []interface{}{} diff --git a/test/compress_test.go b/test/compress_test.go index 44922f4..40d377d 100644 --- a/test/compress_test.go +++ b/test/compress_test.go @@ -28,8 +28,6 @@ import ( ) const testDbnameCompress = "test-gosdk-compress-db.sqlite" -const testCompressKey = "compress" -const testCompressValue = "LZ4" func TestCompress(t *testing.T) { connectionString, _ := os.LookupEnv("SQLITE_CONNECTION_STRING") @@ -38,7 +36,7 @@ func TestCompress(t *testing.T) { url, err := url.Parse(connectionString) values := url.Query() - values.Add(testCompressKey, testCompressValue) + values.Add("compress", sqlitecloud.CompressModeLZ4) url.RawQuery = values.Encode() connstring := url.String()