forked from mirror/pages-server
152 lines
3.2 KiB
Go
152 lines
3.2 KiB
Go
package database
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
|
|
"github.com/go-acme/lego/v4/certificate"
|
|
"xorm.io/xorm"
|
|
|
|
// register sql driver
|
|
_ "github.com/go-sql-driver/mysql"
|
|
_ "github.com/lib/pq"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
var _ CertDB = xDB{}
|
|
|
|
var ErrNotFound = errors.New("entry not found")
|
|
|
|
type xDB struct {
|
|
engine *xorm.Engine
|
|
}
|
|
|
|
func NewXormDB(dbType, dbConn string) (CertDB, error) {
|
|
if !supportedDriver(dbType) {
|
|
return nil, fmt.Errorf("not supported db type '%s'", dbType)
|
|
}
|
|
if dbConn == "" {
|
|
return nil, fmt.Errorf("no db connection provided")
|
|
}
|
|
|
|
e, err := xorm.NewEngine(dbType, dbConn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := e.Sync2(new(Cert)); err != nil {
|
|
return nil, fmt.Errorf("could not sync db model :%w", err)
|
|
}
|
|
|
|
return &xDB{
|
|
engine: e,
|
|
}, nil
|
|
}
|
|
|
|
func (x xDB) Close() error {
|
|
return x.engine.Close()
|
|
}
|
|
|
|
func (x xDB) Put(domain string, cert *certificate.Resource) error {
|
|
log.Trace().Str("domain", cert.Domain).Msg("inserting cert to db")
|
|
|
|
domain = integrationTestReplacements(domain)
|
|
c, err := toCert(domain, cert)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sess := x.engine.NewSession()
|
|
if err := sess.Begin(); err != nil {
|
|
return err
|
|
}
|
|
defer sess.Close()
|
|
|
|
if exist, _ := sess.ID(c.Domain).Exist(new(Cert)); exist {
|
|
if _, err := sess.ID(c.Domain).Update(c); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if _, err = sess.Insert(c); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return sess.Commit()
|
|
}
|
|
|
|
func (x xDB) Get(domain string) (*certificate.Resource, error) {
|
|
// handle wildcard certs
|
|
if domain[:1] == "." {
|
|
domain = "*" + domain
|
|
}
|
|
domain = integrationTestReplacements(domain)
|
|
|
|
cert := new(Cert)
|
|
log.Trace().Str("domain", domain).Msg("get cert from db")
|
|
if found, err := x.engine.ID(domain).Get(cert); err != nil {
|
|
return nil, err
|
|
} else if !found {
|
|
return nil, fmt.Errorf("%w: name='%s'", ErrNotFound, domain)
|
|
}
|
|
return cert.Raw(), nil
|
|
}
|
|
|
|
func (x xDB) Delete(domain string) error {
|
|
// handle wildcard certs
|
|
if domain[:1] == "." {
|
|
domain = "*" + domain
|
|
}
|
|
domain = integrationTestReplacements(domain)
|
|
|
|
log.Trace().Str("domain", domain).Msg("delete cert from db")
|
|
_, err := x.engine.ID(domain).Delete(new(Cert))
|
|
return err
|
|
}
|
|
|
|
// Items return al certs from db, if pageSize is 0 it does not use limit
|
|
func (x xDB) Items(page, pageSize int) ([]*Cert, error) {
|
|
// paginated return
|
|
if pageSize > 0 {
|
|
certs := make([]*Cert, 0, pageSize)
|
|
if page >= 0 {
|
|
page = 1
|
|
}
|
|
err := x.engine.Limit(pageSize, (page-1)*pageSize).Find(&certs)
|
|
return certs, err
|
|
}
|
|
|
|
// return all
|
|
certs := make([]*Cert, 0, 64)
|
|
err := x.engine.Find(&certs)
|
|
return certs, err
|
|
}
|
|
|
|
// Driver names accepted as database type.
const (
	// DriverMysql is the driver name for MySQL.
	DriverMysql = "mysql"
	// DriverPostgres is the driver name for PostgreSQL.
	DriverPostgres = "postgres"
	// DriverSqlite is the driver name for SQLite.
	DriverSqlite = "sqlite3"
)
|
|
|
|
func supportedDriver(driver string) bool {
|
|
switch driver {
|
|
case DriverMysql, DriverPostgres, DriverSqlite:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// integrationTestReplacements exists because the integration tests work with a
// single domain cert whereas production works with a wildcard cert: the
// wildcard key of the mock domain is mapped onto its plain form, every other
// key is returned unchanged.
// TODO: find a better way to handle this
func integrationTestReplacements(domainKey string) string {
	const (
		wildcardKey = "*.localhost.mock.directory"
		plainKey    = "localhost.mock.directory"
	)

	if domainKey != wildcardKey {
		return domainKey
	}
	return plainKey
}
|