path: root/internal
Diffstat (limited to 'internal')
-rw-r--r--  internal/crons/main.go  88
-rw-r--r--  internal/database/migrations/001_init.down.sql  4
-rw-r--r--  internal/database/migrations/001_init.up.sql  25
-rw-r--r--  internal/database/migrations/002_defaults.down.sql  5
-rw-r--r--  internal/database/migrations/002_defaults.up.sql  11
-rw-r--r--  internal/database/migrations/init.go  3
-rw-r--r--  internal/database/migrations/migrations.bindata.go  267
-rw-r--r--  internal/database/repos/defaults.go  25
-rw-r--r--  internal/database/repos/main.go  19
-rw-r--r--  internal/database/sql/main.go  88
-rw-r--r--  internal/handlers/auth.go  162
-rw-r--r--  internal/handlers/main.go  64
-rw-r--r--  internal/handlers/middleware.go  75
-rw-r--r--  internal/models/auth.go  24
-rw-r--r--  internal/models/models.go  45
-rw-r--r--  internal/server/main.go  59
-rw-r--r--  internal/server/router.go  31
17 files changed, 995 insertions, 0 deletions
diff --git a/internal/crons/main.go b/internal/crons/main.go
new file mode 100644
index 0000000..6649992
--- /dev/null
+++ b/internal/crons/main.go
@@ -0,0 +1,88 @@
+package crons
+
+import (
+ "context"
+ "log/slog"
+ "os"
+ "strconv"
+ "time"
+
+ "demoon/internal/database/repos"
+)
+
+var (
+ log = slog.New(slog.NewJSONHandler(os.Stdout,
+ &slog.HandlerOptions{Level: slog.LevelDebug}))
+)
+
+const (
+ CheckBurnTimeKey = "check_burn_time_key"
+)
+
+type Cron struct {
+ ctx context.Context
+ repo repos.FullRepo
+}
+
+func NewCron(
+ ctx context.Context, repo repos.FullRepo,
+) *Cron {
+ return &Cron{
+ ctx: ctx,
+ repo: repo,
+ }
+}
+
+func (c *Cron) UpdateDefaultsFloat64(defaults map[string]float64) (map[string]float64, error) {
+ dm, err := c.repo.DBGetDefaultsMap()
+ if err != nil {
+ return defaults, err
+ }
+ for k := range defaults {
+ sosValue, ok := dm[k]
+ if ok {
+ value, err := strconv.ParseFloat(sosValue, 64)
+ if err != nil {
+ log.Error("failed to parse default value", "key", k, "error", err)
+ continue
+ }
+ defaults[k] = value
+ }
+ }
+ return defaults, nil
+}
+
+func (c *Cron) StartCronJobs() {
+ // check system_options for time
+ defaults := map[string]float64{
+ CheckBurnTimeKey: 5.0,
+ }
+ defaults, err := c.UpdateDefaultsFloat64(defaults)
+ if err != nil {
+ panic(err)
+ }
+ go c.BurnTicker(time.Duration(
+ defaults[CheckBurnTimeKey] * float64(time.Minute)))
+}
+
+func (c *Cron) BurnTicker(interval time.Duration) {
+ funcname := "BurnTicker"
+ ticker := time.NewTicker(interval)
+ for {
+ select {
+ case <-c.ctx.Done():
+ log.Info("cron stopped by ctx.Done", "func", funcname)
+ return
+ case <-ticker.C:
+ c.BurnTimeUpdate()
+ }
+ }
+}
+
+func (c *Cron) BurnTimeUpdate() {
+ // funcname := "BurnTimeUpdate"
+ // get all user scores
+ // if burn time < now
+ // halve the score
+ // move the burn time to 24h from now (using same hours, minutes, seconds)
+}
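
BurnTimeUpdate is still a stub; its comments describe the intended pass: load every user score, and for scores whose burn time has passed, halve the score and push the burn time to the next day at the same wall-clock time. A minimal sketch of that loop follows. It is only an illustration: DBGetAllUserScores and DBUpdateUserScore are hypothetical methods that repos.FullRepo does not expose in this change, so the interface would have to grow first.

// Sketch only — DBGetAllUserScores and DBUpdateUserScore are hypothetical
// repo methods, not part of this diff.
func (c *Cron) BurnTimeUpdate() {
    funcname := "BurnTimeUpdate"
    scores, err := c.repo.DBGetAllUserScores()
    if err != nil {
        log.Error("failed to load user scores", "func", funcname, "error", err)
        return
    }
    now := time.Now()
    for _, us := range scores {
        if us.BurnTime.After(now) {
            continue // not due yet
        }
        us.Score /= 2 // halve the score
        // keep the same hours/minutes/seconds, but move to the next day in the future
        for !us.BurnTime.After(now) {
            us.BurnTime = us.BurnTime.AddDate(0, 0, 1)
        }
        if err := c.repo.DBUpdateUserScore(us); err != nil {
            log.Error("failed to update user score", "func", funcname, "error", err)
        }
    }
}
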
diff --git a/internal/database/migrations/001_init.down.sql b/internal/database/migrations/001_init.down.sql
new file mode 100644
index 0000000..41cc5c9
--- /dev/null
+++ b/internal/database/migrations/001_init.down.sql
@@ -0,0 +1,4 @@
+BEGIN TRANSACTION;
+DROP SCHEMA public CASCADE;
+CREATE SCHEMA public;
+COMMIT;
diff --git a/internal/database/migrations/001_init.up.sql b/internal/database/migrations/001_init.up.sql
new file mode 100644
index 0000000..141a045
--- /dev/null
+++ b/internal/database/migrations/001_init.up.sql
@@ -0,0 +1,25 @@
+BEGIN TRANSACTION;
+CREATE TABLE IF NOT EXISTS user_score (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ username TEXT UNIQUE NOT NULL,
+ password TEXT NOT NULL,
+ burn_time TIMESTAMP NOT NULL,
+ score SMALLINT NOT NULL,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE TABLE IF NOT EXISTS action(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL,
+ magnitude SMALLINT NOT NULL DEFAULT 1,
+ repeatable BOOLEAN NOT NULL DEFAULT FALSE,
+ type TEXT NOT NULL,
+ done BOOLEAN NOT NULL DEFAULT FALSE,
+ username TEXT NOT NULL,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(username, name),
+ CONSTRAINT fk_user_score
+ FOREIGN KEY(username)
+ REFERENCES user_score(username)
+);
+COMMIT;
diff --git a/internal/database/migrations/002_defaults.down.sql b/internal/database/migrations/002_defaults.down.sql
new file mode 100644
index 0000000..3e38fa2
--- /dev/null
+++ b/internal/database/migrations/002_defaults.down.sql
@@ -0,0 +1,5 @@
+BEGIN TRANSACTION;
+
+DROP TABLE defaults;
+
+COMMIT;
diff --git a/internal/database/migrations/002_defaults.up.sql b/internal/database/migrations/002_defaults.up.sql
new file mode 100644
index 0000000..4bf42b7
--- /dev/null
+++ b/internal/database/migrations/002_defaults.up.sql
@@ -0,0 +1,11 @@
+BEGIN TRANSACTION;
+
+CREATE TABLE IF NOT EXISTS defaults (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ key TEXT UNIQUE NOT NULL,
+ value TEXT NOT NULL,
+ description TEXT NOT NULL DEFAULT '',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+COMMIT;
diff --git a/internal/database/migrations/init.go b/internal/database/migrations/init.go
new file mode 100644
index 0000000..2b5a212
--- /dev/null
+++ b/internal/database/migrations/init.go
@@ -0,0 +1,3 @@
+package migrations
+
+//go:generate go-bindata -o ./migrations.bindata.go -pkg migrations -ignore=\\*.go ./...
diff --git a/internal/database/migrations/migrations.bindata.go b/internal/database/migrations/migrations.bindata.go
new file mode 100644
index 0000000..a3276ee
--- /dev/null
+++ b/internal/database/migrations/migrations.bindata.go
@@ -0,0 +1,267 @@
+// Code generated for package migrations by go-bindata DO NOT EDIT. (@generated)
+// sources:
+// 001_init.down.sql
+// 001_init.up.sql
+package migrations
+
+import (
+ "bytes"
+ "compress/gzip"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+)
+
+func bindataRead(data []byte, name string) ([]byte, error) {
+ gz, err := gzip.NewReader(bytes.NewBuffer(data))
+ if err != nil {
+ return nil, fmt.Errorf("Read %q: %v", name, err)
+ }
+
+ var buf bytes.Buffer
+ _, err = io.Copy(&buf, gz)
+ clErr := gz.Close()
+
+ if err != nil {
+ return nil, fmt.Errorf("Read %q: %v", name, err)
+ }
+ if clErr != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+}
+
+type asset struct {
+ bytes []byte
+ info os.FileInfo
+}
+
+type bindataFileInfo struct {
+ name string
+ size int64
+ mode os.FileMode
+ modTime time.Time
+}
+
+// Name return file name
+func (fi bindataFileInfo) Name() string {
+ return fi.name
+}
+
+// Size return file size
+func (fi bindataFileInfo) Size() int64 {
+ return fi.size
+}
+
+// Mode return file mode
+func (fi bindataFileInfo) Mode() os.FileMode {
+ return fi.mode
+}
+
+// Mode return file modify time
+func (fi bindataFileInfo) ModTime() time.Time {
+ return fi.modTime
+}
+
+// IsDir return file whether a directory
+func (fi bindataFileInfo) IsDir() bool {
+ return fi.mode&os.ModeDir != 0
+}
+
+// Sys return file is sys mode
+func (fi bindataFileInfo) Sys() interface{} {
+ return nil
+}
+
+var __001_initDownSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x72\x72\x75\xf7\xf4\xb3\xe6\x72\x09\xf2\x0f\x50\x08\x76\xf6\x70\xf5\x75\x54\x28\x28\x4d\xca\xc9\x4c\x56\x70\x76\x0c\x76\x76\x74\x71\xb5\xe6\x72\x0e\x72\x75\x0c\x71\x45\x95\xb5\xe6\x72\xf6\xf7\xf5\xf5\x0c\xb1\xe6\x02\x04\x00\x00\xff\xff\x47\x78\xff\x61\x41\x00\x00\x00")
+
+func _001_initDownSqlBytes() ([]byte, error) {
+ return bindataRead(
+ __001_initDownSql,
+ "001_init.down.sql",
+ )
+}
+
+func _001_initDownSql() (*asset, error) {
+ bytes, err := _001_initDownSqlBytes()
+ if err != nil {
+ return nil, err
+ }
+
+ info := bindataFileInfo{name: "001_init.down.sql", size: 65, mode: os.FileMode(420), modTime: time.Unix(1712559447, 0)}
+ a := &asset{bytes: bytes, info: info}
+ return a, nil
+}
+
+var __001_initUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x9c\x92\x41\x6f\x82\x40\x10\x85\xef\xfc\x8a\x77\x53\x53\x2f\x9e\x3d\x2d\x38\x98\x4d\x97\xdd\x16\x96\xb4\x9e\xc8\x2a\xdb\x86\x54\xd1\xe0\xda\xc4\x7f\xdf\xe0\x5a\x28\xa9\x49\x1b\xcf\xf3\xcd\x9b\xf7\x66\x26\xa4\x25\x97\xf3\x20\x4a\x89\x69\x82\x66\xa1\x20\x9c\x8e\xb6\x29\x8e\x9b\x7d\x63\x31\x0e\x00\xa0\x2a\xc1\xa5\xc6\x92\x24\xa5\x4c\xd3\x02\xe1\x0a\x0b\x8a\x59\x2e\x34\x58\x06\xbe\x20\xa9\xb9\x5e\x4d\x2f\x70\xdb\x5d\x9b\x9d\x85\xa6\x57\x8d\x5c\xf2\xe7\x9c\x20\x95\x86\xcc\x85\xf0\xc8\xfa\xd4\xd4\x85\xab\x5a\x86\x27\x94\x69\x96\x3c\x75\x44\x27\x2c\xd5\xcb\x78\x82\x07\x54\xb5\xb3\xcd\xa7\xd9\x62\x34\x43\x69\xce\x23\x2f\xe1\xed\x65\x09\x13\xa2\xb5\x36\xd4\xdf\x34\xd6\x38\x5b\x16\xc6\xfd\x39\x20\x98\xcc\x83\x61\x7a\xb3\x71\xd5\xbe\xbe\x27\x79\x9f\x7a\x68\x67\x67\xde\xeb\xca\x9d\xca\x1b\x7e\x3b\xb1\x99\x47\x1b\x7b\xb0\xc6\x99\xf5\xd6\x22\x54\x4a\x10\x93\xbf\xd1\x98\x89\x8c\x3c\xee\xce\x87\x9b\x13\xcb\x7d\xfd\x3f\x81\xe1\xb1\xee\xde\xa2\x6f\xf0\xb7\x1e\x7f\x6b\x4e\x2f\x0b\xb9\xd6\x22\x25\x33\x9d\xb2\x36\xfb\xdb\x47\xd1\x7f\xd8\xa5\x08\xc4\x2a\x25\xbe\x94\x78\xa4\x55\xd7\x3f\xc1\xb5\x98\x52\x4c\x29\xc9\x88\xb2\x1f\xaf\xd9\x63\xed\x09\x23\x95\x24\x5c\xcf\x83\xaf\x00\x00\x00\xff\xff\x9f\xe9\x59\xfc\xcf\x02\x00\x00")
+
+func _001_initUpSqlBytes() ([]byte, error) {
+ return bindataRead(
+ __001_initUpSql,
+ "001_init.up.sql",
+ )
+}
+
+func _001_initUpSql() (*asset, error) {
+ bytes, err := _001_initUpSqlBytes()
+ if err != nil {
+ return nil, err
+ }
+
+ info := bindataFileInfo{name: "001_init.up.sql", size: 719, mode: os.FileMode(420), modTime: time.Unix(1712558242, 0)}
+ a := &asset{bytes: bytes, info: info}
+ return a, nil
+}
+
+// Asset loads and returns the asset for the given name.
+// It returns an error if the asset could not be found or
+// could not be loaded.
+func Asset(name string) ([]byte, error) {
+ cannonicalName := strings.Replace(name, "\\", "/", -1)
+ if f, ok := _bindata[cannonicalName]; ok {
+ a, err := f()
+ if err != nil {
+ return nil, fmt.Errorf("Asset %s can't read by error: %v", name, err)
+ }
+ return a.bytes, nil
+ }
+ return nil, fmt.Errorf("Asset %s not found", name)
+}
+
+// MustAsset is like Asset but panics when Asset would return an error.
+// It simplifies safe initialization of global variables.
+func MustAsset(name string) []byte {
+ a, err := Asset(name)
+ if err != nil {
+ panic("asset: Asset(" + name + "): " + err.Error())
+ }
+
+ return a
+}
+
+// AssetInfo loads and returns the asset info for the given name.
+// It returns an error if the asset could not be found or
+// could not be loaded.
+func AssetInfo(name string) (os.FileInfo, error) {
+ cannonicalName := strings.Replace(name, "\\", "/", -1)
+ if f, ok := _bindata[cannonicalName]; ok {
+ a, err := f()
+ if err != nil {
+ return nil, fmt.Errorf("AssetInfo %s can't read by error: %v", name, err)
+ }
+ return a.info, nil
+ }
+ return nil, fmt.Errorf("AssetInfo %s not found", name)
+}
+
+// AssetNames returns the names of the assets.
+func AssetNames() []string {
+ names := make([]string, 0, len(_bindata))
+ for name := range _bindata {
+ names = append(names, name)
+ }
+ return names
+}
+
+// _bindata is a table, holding each asset generator, mapped to its name.
+var _bindata = map[string]func() (*asset, error){
+ "001_init.down.sql": _001_initDownSql,
+ "001_init.up.sql": _001_initUpSql,
+}
+
+// AssetDir returns the file names below a certain
+// directory embedded in the file by go-bindata.
+// For example if you run go-bindata on data/... and data contains the
+// following hierarchy:
+// data/
+// foo.txt
+// img/
+// a.png
+// b.png
+// then AssetDir("data") would return []string{"foo.txt", "img"}
+// AssetDir("data/img") would return []string{"a.png", "b.png"}
+// AssetDir("foo.txt") and AssetDir("notexist") would return an error
+// AssetDir("") will return []string{"data"}.
+func AssetDir(name string) ([]string, error) {
+ node := _bintree
+ if len(name) != 0 {
+ cannonicalName := strings.Replace(name, "\\", "/", -1)
+ pathList := strings.Split(cannonicalName, "/")
+ for _, p := range pathList {
+ node = node.Children[p]
+ if node == nil {
+ return nil, fmt.Errorf("Asset %s not found", name)
+ }
+ }
+ }
+ if node.Func != nil {
+ return nil, fmt.Errorf("Asset %s not found", name)
+ }
+ rv := make([]string, 0, len(node.Children))
+ for childName := range node.Children {
+ rv = append(rv, childName)
+ }
+ return rv, nil
+}
+
+type bintree struct {
+ Func func() (*asset, error)
+ Children map[string]*bintree
+}
+
+var _bintree = &bintree{nil, map[string]*bintree{
+ "001_init.down.sql": &bintree{_001_initDownSql, map[string]*bintree{}},
+ "001_init.up.sql": &bintree{_001_initUpSql, map[string]*bintree{}},
+}}
+
+// RestoreAsset restores an asset under the given directory
+func RestoreAsset(dir, name string) error {
+ data, err := Asset(name)
+ if err != nil {
+ return err
+ }
+ info, err := AssetInfo(name)
+ if err != nil {
+ return err
+ }
+ err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0755))
+ if err != nil {
+ return err
+ }
+ err = ioutil.WriteFile(_filePath(dir, name), data, info.Mode())
+ if err != nil {
+ return err
+ }
+ err = os.Chtimes(_filePath(dir, name), info.ModTime(), info.ModTime())
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// RestoreAssets restores an asset under the given directory recursively
+func RestoreAssets(dir, name string) error {
+ children, err := AssetDir(name)
+ // File
+ if err != nil {
+ return RestoreAsset(dir, name)
+ }
+ // Dir
+ for _, child := range children {
+ err = RestoreAssets(dir, filepath.Join(name, child))
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func _filePath(dir, name string) string {
+ cannonicalName := strings.Replace(name, "\\", "/", -1)
+ return filepath.Join(append([]string{dir}, strings.Split(cannonicalName, "/")...)...)
+}
diff --git a/internal/database/repos/defaults.go b/internal/database/repos/defaults.go
new file mode 100644
index 0000000..e0b5e52
--- /dev/null
+++ b/internal/database/repos/defaults.go
@@ -0,0 +1,25 @@
+package repos
+
+type DefaultsRepo interface {
+ DBGetDefaultsMap() (map[string]string, error)
+}
+
+func (p *Provider) DBGetDefaultsMap() (map[string]string, error) {
+ rows, err := p.db.Queryx(
+ `SELECT key, value FROM defaults;`)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ res := make(map[string]string)
+ for rows.Next() {
+ keyval, err := rows.SliceScan()
+ if err != nil {
+ return nil, err
+ }
+ key := keyval[0].(string)
+ value := keyval[1].(string)
+ res[key] = value
+ }
+ return res, rows.Err()
+}
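
DBGetDefaultsMap is what lets the crons package override its hard-coded interval, but nothing in this change writes to the defaults table. The helper below is not part of the diff; it is a sketch of how a test or an operator script could seed or override a key such as check_burn_time_key, assuming SQLite upsert syntax.

// Not part of this change — illustration of seeding the defaults table.
func seedDefault(db *sqlx.DB, key, value, description string) error {
    _, err := db.Exec(`INSERT INTO defaults (key, value, description)
        VALUES (?, ?, ?)
        ON CONFLICT(key) DO UPDATE SET value = excluded.value;`,
        key, value, description)
    return err
}

// Example: run the burn ticker every 15 minutes instead of the default 5.
// err := seedDefault(db, "check_burn_time_key", "15", "burn ticker interval in minutes")
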
diff --git a/internal/database/repos/main.go b/internal/database/repos/main.go
new file mode 100644
index 0000000..e57855f
--- /dev/null
+++ b/internal/database/repos/main.go
@@ -0,0 +1,19 @@
+package repos
+
+import (
+ "github.com/jmoiron/sqlx"
+)
+
+type FullRepo interface {
+ DefaultsRepo
+}
+
+type Provider struct {
+ db *sqlx.DB
+}
+
+func NewProvider(conn *sqlx.DB) *Provider {
+ return &Provider{
+ db: conn,
+ }
+}
diff --git a/internal/database/sql/main.go b/internal/database/sql/main.go
new file mode 100644
index 0000000..5a523f6
--- /dev/null
+++ b/internal/database/sql/main.go
@@ -0,0 +1,88 @@
+package database
+
+import (
+ "os"
+ "time"
+
+ "log/slog"
+
+ "github.com/jmoiron/sqlx"
+ _ "github.com/mattn/go-sqlite3"
+ "github.com/pkg/errors"
+)
+
+var (
+ log = slog.New(slog.NewJSONHandler(os.Stdout, nil))
+ dbDriver = "sqlite3"
+)
+
+type DB struct {
+ Conn *sqlx.DB
+ URI string
+}
+
+func (d *DB) CloseAll() error {
+ for _, conn := range []*sqlx.DB{d.Conn} {
+ if err := closeConn(conn); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func closeConn(conn *sqlx.DB) error {
+ return conn.Close()
+}
+
+func Init(DBURI string) (*DB, error) {
+ var result DB
+ var err error
+ result.Conn, err = openDBConnection(DBURI, dbDriver)
+ if err != nil {
+ return nil, err
+ }
+ result.URI = DBURI
+ if err := testConnection(result.Conn); err != nil {
+ return nil, err
+ }
+ return &result, nil
+}
+
+func openDBConnection(dbURI, driver string) (*sqlx.DB, error) {
+ conn, err := sqlx.Open(driver, dbURI)
+ if err != nil {
+ return nil, err
+ }
+ return conn, nil
+}
+
+func testConnection(conn *sqlx.DB) error {
+ err := conn.Ping()
+ if err != nil {
+ return errors.Wrap(err, "can't ping database")
+ }
+ return nil
+}
+
+func (d *DB) PingRoutine(interval time.Duration) {
+ ticker := time.NewTicker(interval)
+ done := make(chan bool)
+ for {
+ select {
+ case <-done:
+ return
+ case t := <-ticker.C:
+ if err := testConnection(d.Conn); err != nil {
+ log.Error("failed to ping the database", "error", err, "ping_at", t)
+ // reconnect
+ if err := closeConn(d.Conn); err != nil {
+ log.Error("failed to close db connection", "error", err, "ping_at", t)
+ }
+ d.Conn, err = openDBConnection(d.URI, dbDriver)
+ if err != nil {
+ log.Error("failed to reconnect", "error", err, "ping_at", t)
+ }
+ }
+ }
+ }
+}
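
Init and PingRoutine are meant to be wired together by the caller; that wiring is outside this diff. A minimal sketch, with a placeholder database path:

// Sketch of the call site (not part of this diff); the path is a placeholder.
func run() error {
    db, err := database.Init("demoon.db")
    if err != nil {
        return err
    }
    defer db.CloseAll()
    // re-ping every 30s; PingRoutine reconnects on failure
    go db.PingRoutine(30 * time.Second)
    // ... construct repos.NewProvider(db.Conn), the handlers and the server here
    return nil
}
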
diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go
new file mode 100644
index 0000000..d0ccdff
--- /dev/null
+++ b/internal/handlers/auth.go
@@ -0,0 +1,162 @@
+package handlers
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+ "demoon/internal/models"
+ "demoon/pkg/utils"
+ "encoding/base64"
+ "encoding/json"
+ "html/template"
+ "net/http"
+ "strings"
+ "time"
+
+ "golang.org/x/crypto/bcrypt"
+)
+
+func abortWithError(w http.ResponseWriter, msg string) {
+ tmpl := template.Must(template.ParseGlob("components/*.html"))
+ tmpl.ExecuteTemplate(w, "error", msg)
+}
+
+func (h *Handlers) HandleSignup(w http.ResponseWriter, r *http.Request) {
+ r.ParseForm()
+ username := r.PostFormValue("username")
+ if username == "" {
+ msg := "username not provided"
+ h.log.Error(msg)
+ abortWithError(w, msg)
+ return
+ }
+ password := r.PostFormValue("password")
+ if password == "" {
+ msg := "password not provided"
+ h.log.Error(msg)
+ abortWithError(w, msg)
+ return
+ }
+ // TODO: make sure the username does not exist yet and persist the new user in the db
+ cleanName := utils.RemoveSpacesFromStr(username)
+ hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), 8)
+ if err != nil { abortWithError(w, err.Error()); return }
+ now := time.Now()
+ nextMidnight := time.Date(now.Year(), now.Month(), now.Day(),
+ 0, 0, 0, 0, time.UTC).Add(time.Hour * 24)
+ newUser := &models.UserScore{
+ Username: cleanName, Password: string(hashedPassword),
+ BurnTime: nextMidnight, CreatedAt: now,
+ }
+ // login user
+ cookie, err := h.makeCookie(cleanName, r.RemoteAddr)
+ if err != nil {
+ h.log.Error("failed to login", "error", err)
+ abortWithError(w, err.Error())
+ return
+ }
+ http.SetCookie(w, cookie)
+ // http.Redirect(w, r, "/", 302)
+ tmpl, err := template.ParseGlob("components/*.html")
+ if err != nil {
+ abortWithError(w, err.Error())
+ return
+ }
+ tmpl.ExecuteTemplate(w, "main", newUser)
+}
+
+func (h *Handlers) HandleLogin(w http.ResponseWriter, r *http.Request) {
+ r.ParseForm()
+ username := r.PostFormValue("username")
+ if username == "" {
+ msg := "username not provided"
+ h.log.Error(msg)
+ abortWithError(w, msg)
+ return
+ }
+ password := r.PostFormValue("password")
+ if password == "" {
+ msg := "password not provided"
+ h.log.Error(msg)
+ abortWithError(w, msg)
+ return
+ }
+ cleanName := utils.RemoveSpacesFromStr(username)
+ tmpl, err := template.ParseGlob("components/*.html")
+ if err != nil {
+ abortWithError(w, err.Error())
+ return
+ }
+ cookie, err := h.makeCookie(cleanName, r.RemoteAddr)
+ if err != nil {
+ h.log.Error("failed to login", "error", err)
+ abortWithError(w, err.Error())
+ return
+ }
+ http.SetCookie(w, cookie)
+ tmpl.ExecuteTemplate(w, "main", nil)
+}
+
+func (h *Handlers) makeCookie(username string, remote string) (*http.Cookie, error) {
+ // secret
+ // Create a new random session token
+ // sessionToken := xid.New().String()
+ sessionToken := "token"
+ expiresAt := time.Now().Add(time.Duration(h.cfg.SessionLifetime) * time.Second)
+ // Set the token in the session map, along with the session information
+ session := &models.Session{
+ Username: username,
+ Expiry: expiresAt,
+ }
+ cookieName := "session_token"
+ // hmac to protect cookies
+ hm := hmac.New(sha256.New, []byte(h.cfg.CookieSecret))
+ hm.Write([]byte(cookieName))
+ hm.Write([]byte(sessionToken))
+ signature := hm.Sum(nil)
+ // b64 enc to avoid non-ascii
+ cookieValue := base64.URLEncoding.EncodeToString([]byte(
+ string(signature) + sessionToken))
+ cookie := &http.Cookie{
+ Name: cookieName,
+ Value: cookieValue,
+ Secure: true,
+ HttpOnly: true,
+ SameSite: http.SameSiteNoneMode,
+ Domain: h.cfg.ServerConfig.Host,
+ }
+ h.log.Info("check remote addr for cookie set",
+ "remote", remote, "session", session)
+ if strings.Contains(remote, "192.168.0") {
+ // no idea what is going on
+ cookie.Domain = "192.168.0.101"
+ }
+ // set ctx?
+ // set user in session
+ if err := h.cacheSetSession(sessionToken, session); err != nil {
+ return nil, err
+ }
+ return cookie, nil
+}
+
+func (h *Handlers) cacheGetSession(key string) (*models.Session, error) {
+ userSessionB, err := h.mc.Get(key)
+ if err != nil {
+ return nil, err
+ }
+ var us *models.Session
+ if err := json.Unmarshal(userSessionB, &us); err != nil {
+ return nil, err
+ }
+ return us, nil
+}
+
+func (h *Handlers) cacheSetSession(key string, session *models.Session) error {
+ sesb, err := json.Marshal(session)
+ if err != nil {
+ return err
+ }
+ h.mc.Set(key, sesb)
+ // expire in 10 min
+ h.mc.Expire(key, 10*60)
+ return nil
+}
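
Two steps are still missing here: HandleSignup builds newUser but never persists it, and HandleLogin issues a session cookie without checking the password. repos.FullRepo currently only exposes DBGetDefaultsMap, so the snippet below uses two hypothetical repo methods, DBGetUserScore and DBCreateUserScore, purely to sketch the shape of those steps.

// Hypothetical additions to the repo interface — not part of this diff.
type UserScoreRepo interface {
    DBGetUserScore(username string) (*models.UserScore, error)
    DBCreateUserScore(us *models.UserScore) error
}

// In HandleSignup, after hashing the password:
//     if _, err := h.repo.DBGetUserScore(cleanName); err == nil {
//         abortWithError(w, "username already taken")
//         return
//     }
//     if err := h.repo.DBCreateUserScore(newUser); err != nil {
//         abortWithError(w, err.Error())
//         return
//     }
//
// In HandleLogin, before makeCookie:
//     us, err := h.repo.DBGetUserScore(cleanName)
//     if err != nil || bcrypt.CompareHashAndPassword(
//         []byte(us.Password), []byte(password)) != nil {
//         abortWithError(w, "wrong username or password")
//         return
//     }
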
diff --git a/internal/handlers/main.go b/internal/handlers/main.go
new file mode 100644
index 0000000..960b26d
--- /dev/null
+++ b/internal/handlers/main.go
@@ -0,0 +1,64 @@
+package handlers
+
+import (
+ "demoon/config"
+ "demoon/internal/database/repos"
+ "demoon/internal/models"
+ "demoon/pkg/cache"
+ "html/template"
+ "log/slog"
+ "net/http"
+ "os"
+)
+
+var defUS = models.UserScore{}
+
+// Handlers structure
+type Handlers struct {
+ cfg config.Config
+ log *slog.Logger
+ repo repos.FullRepo
+ mc cache.Cache
+}
+
+// NewHandlers constructor
+func NewHandlers(
+ cfg config.Config, l *slog.Logger, repo repos.FullRepo,
+) *Handlers {
+ if l == nil {
+ l = slog.New(slog.NewJSONHandler(os.Stdout, nil))
+ }
+ h := &Handlers{
+ cfg: cfg,
+ log: l,
+ repo: repo,
+ mc: cache.MemCache,
+ }
+ return h
+}
+
+func (h *Handlers) Ping(w http.ResponseWriter, r *http.Request) {
+ h.log.Info("got ping request")
+ w.Write([]byte("pong"))
+}
+
+func (h *Handlers) MainPage(w http.ResponseWriter, r *http.Request) {
+ tmpl, err := template.ParseGlob("components/*.html")
+ if err != nil {
+ abortWithError(w, err.Error())
+ return
+ }
+ // get recommendations
+ usernameRaw := r.Context().Value("username")
+ h.log.Info("got mainpage request", "username", usernameRaw)
+ if usernameRaw == nil {
+ tmpl.ExecuteTemplate(w, "main", defUS)
+ return
+ }
+ username := usernameRaw.(string)
+ if username == "" {
+ tmpl.ExecuteTemplate(w, "main", defUS)
+ return
+ }
+ tmpl.ExecuteTemplate(w, "main", nil)
+}
diff --git a/internal/handlers/middleware.go b/internal/handlers/middleware.go
new file mode 100644
index 0000000..8b871a2
--- /dev/null
+++ b/internal/handlers/middleware.go
@@ -0,0 +1,75 @@
+package handlers
+
+import (
+ "context"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "errors"
+ "net/http"
+)
+
+func (h *Handlers) GetSession(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ cookieName := "session_token"
+ sessionCookie, err := r.Cookie(cookieName)
+ if err != nil {
+ msg := "auth failed; failed to get session token from cookies"
+ h.log.Debug(msg, "error", err)
+ next.ServeHTTP(w, r)
+ return
+ }
+ cookieValueB, err := base64.URLEncoding.
+ DecodeString(sessionCookie.Value)
+ if err != nil {
+ msg := "auth failed; failed to decode b64 cookie"
+ h.log.Debug(msg, "error", err)
+ next.ServeHTTP(w, r)
+ return
+ }
+ cookieValue := string(cookieValueB)
+ if len(cookieValue) < sha256.Size {
+ h.log.Warn("small cookie", "size", len(cookieValue))
+ next.ServeHTTP(w, r)
+ return
+ }
+ // Split apart the signature and original cookie value.
+ signature := cookieValue[:sha256.Size]
+ sessionToken := cookieValue[sha256.Size:]
+ //verify signature
+ mac := hmac.New(sha256.New, []byte(h.cfg.CookieSecret))
+ mac.Write([]byte(cookieName))
+ mac.Write([]byte(sessionToken))
+ expectedSignature := mac.Sum(nil)
+ if !hmac.Equal([]byte(signature), expectedSignature) {
+ h.log.Debug("cookie with an invalid sign")
+ next.ServeHTTP(w, r)
+ return
+ }
+ userSession, err := h.cacheGetSession(sessionToken)
+ if err != nil {
+ msg := "auth failed; session does not exist"
+ err = errors.New(msg)
+ h.log.Debug(msg, "error", err)
+ next.ServeHTTP(w, r)
+ return
+ }
+ if userSession.IsExpired() {
+ h.mc.RemoveKey(sessionToken)
+ msg := "session is expired"
+ h.log.Debug(msg, "error", err, "token", sessionToken)
+ next.ServeHTTP(w, r)
+ return
+ }
+ ctx := context.WithValue(r.Context(),
+ "username", userSession.Username)
+ if err := h.cacheSetSession(sessionToken,
+ userSession); err != nil {
+ msg := "failed to marshal user session"
+ h.log.Warn(msg, "error", err)
+ next.ServeHTTP(w, r)
+ return
+ }
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+}
diff --git a/internal/models/auth.go b/internal/models/auth.go
new file mode 100644
index 0000000..9964cd5
--- /dev/null
+++ b/internal/models/auth.go
@@ -0,0 +1,24 @@
+package models
+
+import (
+ "time"
+)
+
+// each session contains the username of the user and the time at which it expires
+type Session struct {
+ Username string
+ Expiry time.Time
+}
+
+// we'll use this method later to determine if the session has expired
+func (s Session) IsExpired() bool {
+ return s.Expiry.Before(time.Now())
+}
+
+func ListUsernames(ss map[string]*Session) []string {
+ resp := make([]string, 0, len(ss))
+ for _, s := range ss {
+ resp = append(resp, s.Username)
+ }
+ return resp
+}
diff --git a/internal/models/models.go b/internal/models/models.go
new file mode 100644
index 0000000..5bc120b
--- /dev/null
+++ b/internal/models/models.go
@@ -0,0 +1,45 @@
+package models
+
+import "time"
+
+type ActionType string
+
+const (
+ ActionTypePlus ActionType = "ActionTypePlus"
+ ActionTypeMinus ActionType = "ActionTypeMinus"
+)
+
+type (
+ UserScore struct {
+ ID uint32 `db:"id"`
+ Username string `db:"username"`
+ Password string `db:"password"`
+ Actions []Action
+ Recommendations []Action
+ BurnTime time.Time `db:"burn_time"`
+ Score int8 `db:"score"`
+ CreatedAt time.Time `db:"created_at"`
+ }
+ Action struct {
+ ID uint32 `db:"id"`
+ Name string `db:"name"`
+ Magnitude uint8 `db:"magnitude"`
+ Repeatable bool `db:"repeatable"`
+ Type ActionType `db:"type"`
+ Done bool `db:"done"`
+ Username string `db:"username"`
+ CreatedAt time.Time `db:"created_at"`
+ }
+)
+
+func (us *UserScore) UpdateScore(act *Action) {
+ switch act.Type {
+ case ActionTypePlus:
+ us.Score += int8(act.Magnitude)
+ if !act.Repeatable {
+ act.Done = true
+ }
+ case ActionTypeMinus:
+ us.Score -= int8(act.Magnitude)
+ }
+}
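
For reference, this is how UpdateScore behaves (values made up for the example): a non-repeatable plus action bumps the score and is marked done, while a minus action only subtracts.

us := &models.UserScore{Username: "alice", Score: 3}

workout := &models.Action{Name: "workout", Magnitude: 2,
    Type: models.ActionTypePlus, Repeatable: false}
us.UpdateScore(workout)
// us.Score == 5, workout.Done == true

skip := &models.Action{Name: "skipped practice", Magnitude: 1,
    Type: models.ActionTypeMinus}
us.UpdateScore(skip)
// us.Score == 4; Done is left untouched for minus actions
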
diff --git a/internal/server/main.go b/internal/server/main.go
new file mode 100644
index 0000000..2c3c8d6
--- /dev/null
+++ b/internal/server/main.go
@@ -0,0 +1,59 @@
+package server
+
+import (
+ "demoon/config"
+ "demoon/internal/database/repos"
+ "demoon/internal/handlers"
+
+ "context"
+ "log/slog"
+ "os"
+ "os/signal"
+ "syscall"
+)
+
+// Server interface
+type Server interface {
+ Listen()
+}
+
+type server struct {
+ config config.Config
+ actions *handlers.Handlers
+ ctx context.Context
+ close context.CancelFunc
+}
+
+func (srv *server) stopOnSignal(close context.CancelFunc) {
+ // listen for termination signals
+ sigc := make(chan os.Signal, 1)
+ signal.Notify(sigc, syscall.SIGINT, syscall.SIGTERM)
+ // SIGINT covers Ctrl-C; SIGTERM is what process managers send on stop
+ sig := <-sigc
+
+ log := slog.New(slog.NewJSONHandler(os.Stdout, nil))
+ log.Info("Shutting down services",
+ "section", "server",
+ "app_event", "terminate",
+ "signal", sig.String())
+ close()
+ os.Exit(0)
+}
+
+func NewServer(cfg config.Config, log *slog.Logger, repo repos.FullRepo) Server {
+ ctx, close := context.WithCancel(context.Background())
+ actions := handlers.NewHandlers(cfg, log, repo)
+ return &server{
+ config: cfg,
+ actions: actions,
+ ctx: ctx,
+ close: close,
+ }
+}
+
+// Listen starts the HTTP server and blocks until a termination signal is received
+func (srv *server) Listen() {
+ // start the http server
+ go srv.ListenToRequests()
+ srv.stopOnSignal(srv.close)
+}
diff --git a/internal/server/router.go b/internal/server/router.go
new file mode 100644
index 0000000..989766c
--- /dev/null
+++ b/internal/server/router.go
@@ -0,0 +1,31 @@
+package server
+
+import (
+ "fmt"
+ "net/http"
+ "time"
+)
+
+func (srv *server) ListenToRequests() {
+ h := srv.actions
+ mux := http.NewServeMux()
+ server := &http.Server{
+ Addr: fmt.Sprintf("localhost:%s", srv.config.ServerConfig.Port),
+ Handler: h.GetSession(mux),
+ ReadTimeout: time.Second * 5,
+ WriteTimeout: time.Second * 5,
+ }
+
+ fs := http.FileServer(http.Dir("assets/"))
+ mux.Handle("GET /assets/", http.StripPrefix("/assets/", fs))
+
+ mux.HandleFunc("GET /ping", h.Ping)
+ mux.HandleFunc("GET /", h.MainPage)
+ mux.HandleFunc("POST /login", h.HandleLogin)
+ mux.HandleFunc("POST /signup", h.HandleSignup)
+
+ // ====== elements ======
+
+ fmt.Println("listening on", server.Addr)
+ fmt.Println("server stopped:", server.ListenAndServe())
+}
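
One gap worth noting: stopOnSignal cancels srv.ctx and then calls os.Exit immediately, and ListenToRequests never looks at the context, so in-flight requests are cut off on shutdown. A possible context-aware variant is sketched below; it assumes "context" is added to router.go's imports and that stopOnSignal is changed to wait for the server to drain instead of exiting right after close().

// Sketch only — assumes stopOnSignal no longer calls os.Exit immediately.
go func() {
    <-srv.ctx.Done()
    shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    if err := server.Shutdown(shutdownCtx); err != nil {
        fmt.Println("shutdown error:", err)
    }
}()
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
    fmt.Println("server stopped:", err)
}
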