diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/crons/main.go | 88 | ||||
-rw-r--r-- | internal/database/migrations/001_init.down.sql | 4 | ||||
-rw-r--r-- | internal/database/migrations/001_init.up.sql | 25 | ||||
-rw-r--r-- | internal/database/migrations/002_defaults.down.sql | 5 | ||||
-rw-r--r-- | internal/database/migrations/002_defaults.up.sql | 11 | ||||
-rw-r--r-- | internal/database/migrations/init.go | 3 | ||||
-rw-r--r-- | internal/database/migrations/migrations.bindata.go | 267 | ||||
-rw-r--r-- | internal/database/repos/defaults.go | 25 | ||||
-rw-r--r-- | internal/database/repos/main.go | 19 | ||||
-rw-r--r-- | internal/database/sql/main.go | 88 | ||||
-rw-r--r-- | internal/handlers/auth.go | 162 | ||||
-rw-r--r-- | internal/handlers/main.go | 64 | ||||
-rw-r--r-- | internal/handlers/middleware.go | 75 | ||||
-rw-r--r-- | internal/models/auth.go | 24 | ||||
-rw-r--r-- | internal/models/models.go | 45 | ||||
-rw-r--r-- | internal/server/main.go | 59 | ||||
-rw-r--r-- | internal/server/router.go | 31 |
17 files changed, 995 insertions, 0 deletions
diff --git a/internal/crons/main.go b/internal/crons/main.go new file mode 100644 index 0000000..6649992 --- /dev/null +++ b/internal/crons/main.go @@ -0,0 +1,88 @@ +package crons + +import ( + "demoon/internal/database/repos" + "context" + "os" + "strconv" + "time" + + "log/slog" +) + +var ( + log = slog.New(slog.NewJSONHandler(os.Stdout, + &slog.HandlerOptions{Level: slog.LevelDebug})) +) + +const ( + CheckBurnTimeKey = "check_burn_time_key" +) + +type Cron struct { + ctx context.Context + repo repos.FullRepo +} + +func NewCron( + ctx context.Context, repo repos.FullRepo, +) *Cron { + return &Cron{ + ctx: ctx, + repo: repo, + } +} + +func (c *Cron) UpdateDefaultsFloat64(defaults map[string]float64) (map[string]float64, error) { + dm, err := c.repo.DBGetDefaultsMap() + if err != nil { + return defaults, err + } + for k, _ := range defaults { + sosValue, ok := dm[k] + if ok { + value, err := strconv.ParseFloat(sosValue, 64) + if err != nil { + // log err + continue + } + defaults[k] = value + } + } + return defaults, nil +} + +func (c *Cron) StartCronJobs() { + // check system_options for time + defaults := map[string]float64{ + CheckBurnTimeKey: 5.0, + } + defaults, err := c.UpdateDefaultsFloat64(defaults) + if err != nil { + panic(err) + } + go c.BurnTicker( + time.Minute * time.Duration(defaults[CheckBurnTimeKey])) +} + +func (c *Cron) BurnTicker(interval time.Duration) { + funcname := "BurnTicker" + ticker := time.NewTicker(interval) + for { + select { + case <-c.ctx.Done(): + log.Info("cron stopped by ctx.Done", "func", funcname) + return + case <-ticker.C: + c.BurnTimeUpdate() + } + } +} + +func (c *Cron) BurnTimeUpdate() { + // funcname := "BurnTimeUpdate" + // get all user scores + // if burn time < now + // halv the score + // move the burn time to 24h from now (using same hours, minutes, seconds) +} diff --git a/internal/database/migrations/001_init.down.sql b/internal/database/migrations/001_init.down.sql new file mode 100644 index 0000000..41cc5c9 --- /dev/null +++ b/internal/database/migrations/001_init.down.sql @@ -0,0 +1,4 @@ +BEGIN TRANSACTION; +DROP SCHEMA public CASCADE; +CREATE SCHEMA public; +COMMIT; diff --git a/internal/database/migrations/001_init.up.sql b/internal/database/migrations/001_init.up.sql new file mode 100644 index 0000000..141a045 --- /dev/null +++ b/internal/database/migrations/001_init.up.sql @@ -0,0 +1,25 @@ +BEGIN TRANSACTION; +CREATE TABLE IF NOT EXISTS user_score ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + password TEXT NOT NULL, + burn_time TIMESTAMP NOT NULL, + score SMALLINT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS action( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + magnitude SMALLINT NOT NULL DEFAULT 1, + repeatable BOOLEAN NOT NULL DEFAULT FALSE, + type TEXT NOT NULL, + done BOOLEAN NOT NULL DEFAULT FALSE, + username TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + UNIQUE(username, name), + CONSTRAINT fk_user_score + FOREIGN KEY(username) + REFERENCES user_score(username) +); +COMMIT; diff --git a/internal/database/migrations/002_defaults.down.sql b/internal/database/migrations/002_defaults.down.sql new file mode 100644 index 0000000..3e38fa2 --- /dev/null +++ b/internal/database/migrations/002_defaults.down.sql @@ -0,0 +1,5 @@ +BEGIN TRANSACTION; + +DROP TABLE defaults; + +COMMIT; diff --git a/internal/database/migrations/002_defaults.up.sql b/internal/database/migrations/002_defaults.up.sql new file mode 100644 index 0000000..4bf42b7 --- /dev/null +++ b/internal/database/migrations/002_defaults.up.sql @@ -0,0 +1,11 @@ +BEGIN TRANSACTION; + +CREATE TABLE IF NOT EXISTS defaults ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + key TEXT UNIQUE NOT NULL, + value TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +COMMIT; diff --git a/internal/database/migrations/init.go b/internal/database/migrations/init.go new file mode 100644 index 0000000..2b5a212 --- /dev/null +++ b/internal/database/migrations/init.go @@ -0,0 +1,3 @@ +package migrations + +//go:generate go-bindata -o ./migrations.bindata.go -pkg migrations -ignore=\\*.go ./... diff --git a/internal/database/migrations/migrations.bindata.go b/internal/database/migrations/migrations.bindata.go new file mode 100644 index 0000000..a3276ee --- /dev/null +++ b/internal/database/migrations/migrations.bindata.go @@ -0,0 +1,267 @@ +// Code generated for package migrations by go-bindata DO NOT EDIT. (@generated) +// sources: +// 001_init.down.sql +// 001_init.up.sql +package migrations + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" +) + +func bindataRead(data []byte, name string) ([]byte, error) { + gz, err := gzip.NewReader(bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + + var buf bytes.Buffer + _, err = io.Copy(&buf, gz) + clErr := gz.Close() + + if err != nil { + return nil, fmt.Errorf("Read %q: %v", name, err) + } + if clErr != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +type asset struct { + bytes []byte + info os.FileInfo +} + +type bindataFileInfo struct { + name string + size int64 + mode os.FileMode + modTime time.Time +} + +// Name return file name +func (fi bindataFileInfo) Name() string { + return fi.name +} + +// Size return file size +func (fi bindataFileInfo) Size() int64 { + return fi.size +} + +// Mode return file mode +func (fi bindataFileInfo) Mode() os.FileMode { + return fi.mode +} + +// Mode return file modify time +func (fi bindataFileInfo) ModTime() time.Time { + return fi.modTime +} + +// IsDir return file whether a directory +func (fi bindataFileInfo) IsDir() bool { + return fi.mode&os.ModeDir != 0 +} + +// Sys return file is sys mode +func (fi bindataFileInfo) Sys() interface{} { + return nil +} + +var __001_initDownSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x72\x72\x75\xf7\xf4\xb3\xe6\x72\x09\xf2\x0f\x50\x08\x76\xf6\x70\xf5\x75\x54\x28\x28\x4d\xca\xc9\x4c\x56\x70\x76\x0c\x76\x76\x74\x71\xb5\xe6\x72\x0e\x72\x75\x0c\x71\x45\x95\xb5\xe6\x72\xf6\xf7\xf5\xf5\x0c\xb1\xe6\x02\x04\x00\x00\xff\xff\x47\x78\xff\x61\x41\x00\x00\x00") + +func _001_initDownSqlBytes() ([]byte, error) { + return bindataRead( + __001_initDownSql, + "001_init.down.sql", + ) +} + +func _001_initDownSql() (*asset, error) { + bytes, err := _001_initDownSqlBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "001_init.down.sql", size: 65, mode: os.FileMode(420), modTime: time.Unix(1712559447, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +var __001_initUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x9c\x92\x41\x6f\x82\x40\x10\x85\xef\xfc\x8a\x77\x53\x53\x2f\x9e\x3d\x2d\x38\x98\x4d\x97\xdd\x16\x96\xb4\x9e\xc8\x2a\xdb\x86\x54\xd1\xe0\xda\xc4\x7f\xdf\xe0\x5a\x28\xa9\x49\x1b\xcf\xf3\xcd\x9b\xf7\x66\x26\xa4\x25\x97\xf3\x20\x4a\x89\x69\x82\x66\xa1\x20\x9c\x8e\xb6\x29\x8e\x9b\x7d\x63\x31\x0e\x00\xa0\x2a\xc1\xa5\xc6\x92\x24\xa5\x4c\xd3\x02\xe1\x0a\x0b\x8a\x59\x2e\x34\x58\x06\xbe\x20\xa9\xb9\x5e\x4d\x2f\x70\xdb\x5d\x9b\x9d\x85\xa6\x57\x8d\x5c\xf2\xe7\x9c\x20\x95\x86\xcc\x85\xf0\xc8\xfa\xd4\xd4\x85\xab\x5a\x86\x27\x94\x69\x96\x3c\x75\x44\x27\x2c\xd5\xcb\x78\x82\x07\x54\xb5\xb3\xcd\xa7\xd9\x62\x34\x43\x69\xce\x23\x2f\xe1\xed\x65\x09\x13\xa2\xb5\x36\xd4\xdf\x34\xd6\x38\x5b\x16\xc6\xfd\x39\x20\x98\xcc\x83\x61\x7a\xb3\x71\xd5\xbe\xbe\x27\x79\x9f\x7a\x68\x67\x67\xde\xeb\xca\x9d\xca\x1b\x7e\x3b\xb1\x99\x47\x1b\x7b\xb0\xc6\x99\xf5\xd6\x22\x54\x4a\x10\x93\xbf\xd1\x98\x89\x8c\x3c\xee\xce\x87\x9b\x13\xcb\x7d\xfd\x3f\x81\xe1\xb1\xee\xde\xa2\x6f\xf0\xb7\x1e\x7f\x6b\x4e\x2f\x0b\xb9\xd6\x22\x25\x33\x9d\xb2\x36\xfb\xdb\x47\xd1\x7f\xd8\xa5\x08\xc4\x2a\x25\xbe\x94\x78\xa4\x55\xd7\x3f\xc1\xb5\x98\x52\x4c\x29\xc9\x88\xb2\x1f\xaf\xd9\x63\xed\x09\x23\x95\x24\x5c\xcf\x83\xaf\x00\x00\x00\xff\xff\x9f\xe9\x59\xfc\xcf\x02\x00\x00") + +func _001_initUpSqlBytes() ([]byte, error) { + return bindataRead( + __001_initUpSql, + "001_init.up.sql", + ) +} + +func _001_initUpSql() (*asset, error) { + bytes, err := _001_initUpSqlBytes() + if err != nil { + return nil, err + } + + info := bindataFileInfo{name: "001_init.up.sql", size: 719, mode: os.FileMode(420), modTime: time.Unix(1712558242, 0)} + a := &asset{bytes: bytes, info: info} + return a, nil +} + +// Asset loads and returns the asset for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func Asset(name string) ([]byte, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("Asset %s can't read by error: %v", name, err) + } + return a.bytes, nil + } + return nil, fmt.Errorf("Asset %s not found", name) +} + +// MustAsset is like Asset but panics when Asset would return an error. +// It simplifies safe initialization of global variables. +func MustAsset(name string) []byte { + a, err := Asset(name) + if err != nil { + panic("asset: Asset(" + name + "): " + err.Error()) + } + + return a +} + +// AssetInfo loads and returns the asset info for the given name. +// It returns an error if the asset could not be found or +// could not be loaded. +func AssetInfo(name string) (os.FileInfo, error) { + cannonicalName := strings.Replace(name, "\\", "/", -1) + if f, ok := _bindata[cannonicalName]; ok { + a, err := f() + if err != nil { + return nil, fmt.Errorf("AssetInfo %s can't read by error: %v", name, err) + } + return a.info, nil + } + return nil, fmt.Errorf("AssetInfo %s not found", name) +} + +// AssetNames returns the names of the assets. +func AssetNames() []string { + names := make([]string, 0, len(_bindata)) + for name := range _bindata { + names = append(names, name) + } + return names +} + +// _bindata is a table, holding each asset generator, mapped to its name. +var _bindata = map[string]func() (*asset, error){ + "001_init.down.sql": _001_initDownSql, + "001_init.up.sql": _001_initUpSql, +} + +// AssetDir returns the file names below a certain +// directory embedded in the file by go-bindata. +// For example if you run go-bindata on data/... and data contains the +// following hierarchy: +// data/ +// foo.txt +// img/ +// a.png +// b.png +// then AssetDir("data") would return []string{"foo.txt", "img"} +// AssetDir("data/img") would return []string{"a.png", "b.png"} +// AssetDir("foo.txt") and AssetDir("notexist") would return an error +// AssetDir("") will return []string{"data"}. +func AssetDir(name string) ([]string, error) { + node := _bintree + if len(name) != 0 { + cannonicalName := strings.Replace(name, "\\", "/", -1) + pathList := strings.Split(cannonicalName, "/") + for _, p := range pathList { + node = node.Children[p] + if node == nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + } + } + if node.Func != nil { + return nil, fmt.Errorf("Asset %s not found", name) + } + rv := make([]string, 0, len(node.Children)) + for childName := range node.Children { + rv = append(rv, childName) + } + return rv, nil +} + +type bintree struct { + Func func() (*asset, error) + Children map[string]*bintree +} + +var _bintree = &bintree{nil, map[string]*bintree{ + "001_init.down.sql": &bintree{_001_initDownSql, map[string]*bintree{}}, + "001_init.up.sql": &bintree{_001_initUpSql, map[string]*bintree{}}, +}} + +// RestoreAsset restores an asset under the given directory +func RestoreAsset(dir, name string) error { + data, err := Asset(name) + if err != nil { + return err + } + info, err := AssetInfo(name) + if err != nil { + return err + } + err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0755)) + if err != nil { + return err + } + err = ioutil.WriteFile(_filePath(dir, name), data, info.Mode()) + if err != nil { + return err + } + err = os.Chtimes(_filePath(dir, name), info.ModTime(), info.ModTime()) + if err != nil { + return err + } + return nil +} + +// RestoreAssets restores an asset under the given directory recursively +func RestoreAssets(dir, name string) error { + children, err := AssetDir(name) + // File + if err != nil { + return RestoreAsset(dir, name) + } + // Dir + for _, child := range children { + err = RestoreAssets(dir, filepath.Join(name, child)) + if err != nil { + return err + } + } + return nil +} + +func _filePath(dir, name string) string { + cannonicalName := strings.Replace(name, "\\", "/", -1) + return filepath.Join(append([]string{dir}, strings.Split(cannonicalName, "/")...)...) +} diff --git a/internal/database/repos/defaults.go b/internal/database/repos/defaults.go new file mode 100644 index 0000000..e0b5e52 --- /dev/null +++ b/internal/database/repos/defaults.go @@ -0,0 +1,25 @@ +package repos + +type DefaultsRepo interface { + DBGetDefaultsMap() (map[string]string, error) +} + +func (p *Provider) DBGetDefaultsMap() (map[string]string, error) { + rows, err := p.db.Queryx(`SELECT key, value + FROM defaults; + `) + if err != nil { + return nil, err + } + res := make(map[string]string) + for rows.Next() { + keyval, err := rows.SliceScan() + if err != nil { + return nil, err + } + key := keyval[0].(string) + value := keyval[1].(string) + res[key] = value + } + return res, nil +} diff --git a/internal/database/repos/main.go b/internal/database/repos/main.go new file mode 100644 index 0000000..e57855f --- /dev/null +++ b/internal/database/repos/main.go @@ -0,0 +1,19 @@ +package repos + +import ( + "github.com/jmoiron/sqlx" +) + +type FullRepo interface { + DefaultsRepo +} + +type Provider struct { + db *sqlx.DB +} + +func NewProvider(conn *sqlx.DB) *Provider { + return &Provider{ + db: conn, + } +} diff --git a/internal/database/sql/main.go b/internal/database/sql/main.go new file mode 100644 index 0000000..5a523f6 --- /dev/null +++ b/internal/database/sql/main.go @@ -0,0 +1,88 @@ +package database + +import ( + "os" + "time" + + "log/slog" + + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" + "github.com/pkg/errors" +) + +var ( + log = slog.New(slog.NewJSONHandler(os.Stdout, nil)) + dbDriver = "sqlite3" +) + +type DB struct { + Conn *sqlx.DB + URI string +} + +func (d *DB) CloseAll() error { + for _, conn := range []*sqlx.DB{d.Conn} { + if err := closeConn(conn); err != nil { + return err + } + } + return nil +} + +func closeConn(conn *sqlx.DB) error { + return conn.Close() +} + +func Init(DBURI string) (*DB, error) { + var result DB + var err error + result.Conn, err = openDBConnection(DBURI, dbDriver) + if err != nil { + return nil, err + } + result.URI = DBURI + if err := testConnection(result.Conn); err != nil { + return nil, err + } + return &result, nil +} + +func openDBConnection(dbURI, driver string) (*sqlx.DB, error) { + conn, err := sqlx.Open(driver, dbURI) + if err != nil { + return nil, err + } + return conn, nil +} + +func testConnection(conn *sqlx.DB) error { + err := conn.Ping() + if err != nil { + return errors.Wrap(err, "can't ping database") + } + return nil +} + +func (d *DB) PingRoutine(interval time.Duration) { + ticker := time.NewTicker(interval) + done := make(chan bool) + for { + select { + case <-done: + return + case t := <-ticker.C: + if err := testConnection(d.Conn); err != nil { + log.Error("failed to ping postrges db", "error", err, "ping_at", t) + // reconnect + if err := closeConn(d.Conn); err != nil { + log.Error("failed to close db connection", "error", err, "ping_at", t) + } + d.Conn, err = openDBConnection(d.URI, dbDriver) + if err != nil { + log.Error("failed to reconnect", "error", err, "ping_at", t) + } + } + } + } +} diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go new file mode 100644 index 0000000..d0ccdff --- /dev/null +++ b/internal/handlers/auth.go @@ -0,0 +1,162 @@ +package handlers + +import ( + "crypto/hmac" + "crypto/sha256" + "demoon/internal/models" + "demoon/pkg/utils" + "encoding/base64" + "encoding/json" + "html/template" + "net/http" + "strings" + "time" + + "golang.org/x/crypto/bcrypt" +) + +func abortWithError(w http.ResponseWriter, msg string) { + tmpl := template.Must(template.ParseGlob("components/*.html")) + tmpl.ExecuteTemplate(w, "error", msg) +} + +func (h *Handlers) HandleSignup(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + username := r.PostFormValue("username") + if username == "" { + msg := "username not provided" + h.log.Error(msg) + abortWithError(w, msg) + return + } + password := r.PostFormValue("password") + if password == "" { + msg := "password not provided" + h.log.Error(msg) + abortWithError(w, msg) + return + } + // make sure username does not exists + cleanName := utils.RemoveSpacesFromStr(username) + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), 8) + // create user in db + now := time.Now() + nextMidnight := time.Date(now.Year(), now.Month(), now.Day(), + 0, 0, 0, 0, time.UTC).Add(time.Hour * 24) + newUser := &models.UserScore{ + Username: cleanName, Password: string(hashedPassword), + BurnTime: nextMidnight, CreatedAt: now, + } + // login user + cookie, err := h.makeCookie(cleanName, r.RemoteAddr) + if err != nil { + h.log.Error("failed to login", "error", err) + abortWithError(w, err.Error()) + return + } + http.SetCookie(w, cookie) + // http.Redirect(w, r, "/", 302) + tmpl, err := template.ParseGlob("components/*.html") + if err != nil { + abortWithError(w, err.Error()) + return + } + tmpl.ExecuteTemplate(w, "main", newUser) +} + +func (h *Handlers) HandleLogin(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + username := r.PostFormValue("username") + if username == "" { + msg := "username not provided" + h.log.Error(msg) + abortWithError(w, msg) + return + } + password := r.PostFormValue("password") + if password == "" { + msg := "password not provided" + h.log.Error(msg) + abortWithError(w, msg) + return + } + cleanName := utils.RemoveSpacesFromStr(username) + tmpl, err := template.ParseGlob("components/*.html") + if err != nil { + abortWithError(w, err.Error()) + return + } + cookie, err := h.makeCookie(cleanName, r.RemoteAddr) + if err != nil { + h.log.Error("failed to login", "error", err) + abortWithError(w, err.Error()) + return + } + http.SetCookie(w, cookie) + tmpl.ExecuteTemplate(w, "main", nil) +} + +func (h *Handlers) makeCookie(username string, remote string) (*http.Cookie, error) { + // secret + // Create a new random session token + // sessionToken := xid.New().String() + sessionToken := "token" + expiresAt := time.Now().Add(time.Duration(h.cfg.SessionLifetime) * time.Second) + // Set the token in the session map, along with the session information + session := &models.Session{ + Username: username, + Expiry: expiresAt, + } + cookieName := "session_token" + // hmac to protect cookies + hm := hmac.New(sha256.New, []byte(h.cfg.CookieSecret)) + hm.Write([]byte(cookieName)) + hm.Write([]byte(sessionToken)) + signature := hm.Sum(nil) + // b64 enc to avoid non-ascii + cookieValue := base64.URLEncoding.EncodeToString([]byte( + string(signature) + sessionToken)) + cookie := &http.Cookie{ + Name: cookieName, + Value: cookieValue, + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteNoneMode, + Domain: h.cfg.ServerConfig.Host, + } + h.log.Info("check remote addr for cookie set", + "remote", remote, "session", session) + if strings.Contains(remote, "192.168.0") { + // no idea what is going on + cookie.Domain = "192.168.0.101" + } + // set ctx? + // set user in session + if err := h.cacheSetSession(sessionToken, session); err != nil { + return nil, err + } + return cookie, nil +} + +func (h *Handlers) cacheGetSession(key string) (*models.Session, error) { + userSessionB, err := h.mc.Get(key) + if err != nil { + return nil, err + } + var us *models.Session + if err := json.Unmarshal(userSessionB, &us); err != nil { + return nil, err + } + return us, nil +} + +func (h *Handlers) cacheSetSession(key string, session *models.Session) error { + sesb, err := json.Marshal(session) + if err != nil { + return err + } + h.mc.Set(key, sesb) + // expire in 10 min + h.mc.Expire(key, 10*60) + return nil +} diff --git a/internal/handlers/main.go b/internal/handlers/main.go new file mode 100644 index 0000000..960b26d --- /dev/null +++ b/internal/handlers/main.go @@ -0,0 +1,64 @@ +package handlers + +import ( + "demoon/config" + "demoon/internal/database/repos" + "demoon/internal/models" + "demoon/pkg/cache" + "html/template" + "log/slog" + "net/http" + "os" +) + +var defUS = models.UserScore{} + +// Handlers structure +type Handlers struct { + cfg config.Config + log *slog.Logger + repo repos.FullRepo + mc cache.Cache +} + +// NewHandlers constructor +func NewHandlers( + cfg config.Config, l *slog.Logger, repo repos.FullRepo, +) *Handlers { + if l == nil { + l = slog.New(slog.NewJSONHandler(os.Stdout, nil)) + } + h := &Handlers{ + cfg: cfg, + log: l, + repo: repo, + mc: cache.MemCache, + } + return h +} + +func (h *Handlers) Ping(w http.ResponseWriter, r *http.Request) { + h.log.Info("got ping request") + w.Write([]byte("pong")) +} + +func (h *Handlers) MainPage(w http.ResponseWriter, r *http.Request) { + tmpl, err := template.ParseGlob("components/*.html") + if err != nil { + abortWithError(w, err.Error()) + return + } + // get recommendations + usernameRaw := r.Context().Value("username") + h.log.Info("got mainpage request", "username", usernameRaw) + if usernameRaw == nil { + tmpl.ExecuteTemplate(w, "main", defUS) + return + } + username := usernameRaw.(string) + if username == "" { + tmpl.ExecuteTemplate(w, "main", defUS) + return + } + tmpl.ExecuteTemplate(w, "main", nil) +} diff --git a/internal/handlers/middleware.go b/internal/handlers/middleware.go new file mode 100644 index 0000000..8b871a2 --- /dev/null +++ b/internal/handlers/middleware.go @@ -0,0 +1,75 @@ +package handlers + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "errors" + "net/http" +) + +func (h *Handlers) GetSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookieName := "session_token" + sessionCookie, err := r.Cookie(cookieName) + if err != nil { + msg := "auth failed; failed to get session token from cookies" + h.log.Debug(msg, "error", err) + next.ServeHTTP(w, r) + return + } + cookieValueB, err := base64.URLEncoding. + DecodeString(sessionCookie.Value) + if err != nil { + msg := "auth failed; failed to decode b64 cookie" + h.log.Debug(msg, "error", err) + next.ServeHTTP(w, r) + return + } + cookieValue := string(cookieValueB) + if len(cookieValue) < sha256.Size { + h.log.Warn("small cookie", "size", len(cookieValue)) + next.ServeHTTP(w, r) + return + } + // Split apart the signature and original cookie value. + signature := cookieValue[:sha256.Size] + sessionToken := cookieValue[sha256.Size:] + //verify signature + mac := hmac.New(sha256.New, []byte(h.cfg.CookieSecret)) + mac.Write([]byte(cookieName)) + mac.Write([]byte(sessionToken)) + expectedSignature := mac.Sum(nil) + if !hmac.Equal([]byte(signature), expectedSignature) { + h.log.Debug("cookie with an invalid sign") + next.ServeHTTP(w, r) + return + } + userSession, err := h.cacheGetSession(sessionToken) + if err != nil { + msg := "auth failed; session does not exists" + err = errors.New(msg) + h.log.Debug(msg, "error", err) + next.ServeHTTP(w, r) + return + } + if userSession.IsExpired() { + h.mc.RemoveKey(sessionToken) + msg := "session is expired" + h.log.Debug(msg, "error", err, "token", sessionToken) + next.ServeHTTP(w, r) + return + } + ctx := context.WithValue(r.Context(), + "username", userSession.Username) + if err := h.cacheSetSession(sessionToken, + userSession); err != nil { + msg := "failed to marshal user session" + h.log.Warn(msg, "error", err) + next.ServeHTTP(w, r) + return + } + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/internal/models/auth.go b/internal/models/auth.go new file mode 100644 index 0000000..9964cd5 --- /dev/null +++ b/internal/models/auth.go @@ -0,0 +1,24 @@ +package models + +import ( + "time" +) + +// each session contains the username of the user and the time at which it expires +type Session struct { + Username string + Expiry time.Time +} + +// we'll use this method later to determine if the session has expired +func (s Session) IsExpired() bool { + return s.Expiry.Before(time.Now()) +} + +func ListUsernames(ss map[string]*Session) []string { + resp := make([]string, 0, len(ss)) + for _, s := range ss { + resp = append(resp, s.Username) + } + return resp +} diff --git a/internal/models/models.go b/internal/models/models.go new file mode 100644 index 0000000..5bc120b --- /dev/null +++ b/internal/models/models.go @@ -0,0 +1,45 @@ +package models + +import "time" + +type ActionType string + +const ( + ActionTypePlus ActionType = "ActionTypePlus" + ActionTypeMinus ActionType = "ActionTypeMinus" +) + +type ( + UserScore struct { + ID uint32 `db:"id"` + Username string `db:"username"` + Password string `db:"password"` + Actions []Action + Recommendations []Action + BurnTime time.Time `db:"burn_time"` + Score int8 `db:"score"` + CreatedAt time.Time `db:"created_at"` + } + Action struct { + ID uint32 `db:"id"` + Name string `db:"name"` + Magnitude uint8 `db:"magnitude"` + Repeatable bool `db:"repeatable"` + Type ActionType `db:"type"` + Done bool `db:"done"` + Username string `db:"username"` + CreatedAt time.Time `db:"created_at"` + } +) + +func (us *UserScore) UpdateScore(act *Action) { + switch act.Type { + case ActionTypePlus: + us.Score += int8(act.Magnitude) + if !act.Repeatable { + act.Done = true + } + case ActionTypeMinus: + us.Score -= int8(act.Magnitude) + } +} diff --git a/internal/server/main.go b/internal/server/main.go new file mode 100644 index 0000000..2c3c8d6 --- /dev/null +++ b/internal/server/main.go @@ -0,0 +1,59 @@ +package server + +import ( + "demoon/config" + "demoon/internal/database/repos" + "demoon/internal/handlers" + + "context" + "log/slog" + "os" + "os/signal" + "syscall" +) + +// Server interface +type Server interface { + Listen() +} + +type server struct { + config config.Config + actions *handlers.Handlers + ctx context.Context + close context.CancelFunc +} + +func (srv *server) stopOnSignal(close context.CancelFunc) { + // listen for termination signals + sigc := make(chan os.Signal, 1) + signal.Notify(sigc, os.Interrupt, syscall.SIGINT) + signal.Notify(sigc, os.Interrupt, syscall.SIGTERM) + sig := <-sigc + + log := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + log.Info("Shutting down services", + "section", "server", + "app_event", "terminate", + "signal", sig.String()) + close() + os.Exit(0) +} + +func NewServer(cfg config.Config, log *slog.Logger, repo repos.FullRepo) Server { + ctx, close := context.WithCancel(context.Background()) + actions := handlers.NewHandlers(cfg, log, repo) + return &server{ + config: cfg, + actions: actions, + ctx: ctx, + close: close, + } +} + +// Listen for new events that affect the market and process them +func (srv *server) Listen() { + // start the http server + go srv.ListenToRequests() + srv.stopOnSignal(srv.close) +} diff --git a/internal/server/router.go b/internal/server/router.go new file mode 100644 index 0000000..989766c --- /dev/null +++ b/internal/server/router.go @@ -0,0 +1,31 @@ +package server + +import ( + "fmt" + "net/http" + "time" +) + +func (srv *server) ListenToRequests() { + h := srv.actions + mux := http.NewServeMux() + server := &http.Server{ + Addr: fmt.Sprintf("localhost:%s", srv.config.ServerConfig.Port), + Handler: h.GetSession(mux), + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 5, + } + + fs := http.FileServer(http.Dir("assets/")) + mux.Handle("GET /assets/", http.StripPrefix("/assets/", fs)) + + mux.HandleFunc("GET /ping", h.Ping) + mux.HandleFunc("GET /", h.MainPage) + mux.HandleFunc("POST /login", h.HandleLogin) + mux.HandleFunc("POST /signup", h.HandleSignup) + + // ====== elements ====== + + fmt.Println("Listening", "addr", server.Addr) + server.ListenAndServe() +} |