summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--cmd/start.go3
-rw-r--r--components/auth.html21
-rw-r--r--components/index.html7
-rw-r--r--internal/database/repos/main.go5
-rw-r--r--internal/handlers/auth.go55
-rw-r--r--internal/handlers/main.go26
-rw-r--r--internal/handlers/middleware.go48
-rw-r--r--internal/server/router.go4
-rw-r--r--pkg/cache/interface.go9
-rw-r--r--pkg/cache/main.go146
11 files changed, 301 insertions, 24 deletions
diff --git a/.gitignore b/.gitignore
index 43ba29c..6b8fc35 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@ secrethitler
config.yml
apjournal
tags
+store.json
diff --git a/cmd/start.go b/cmd/start.go
index a1f051c..5d2f5b9 100644
--- a/cmd/start.go
+++ b/cmd/start.go
@@ -31,7 +31,8 @@ var startCmd = &cobra.Command{
cfg := config.LoadConfig(viper.GetViper())
db, err := database.InitWithMigrate(cfg.DBURI, true)
if err != nil {
- panic(err)
+ log.Error("failed to connect to db", "error", err)
+ return
}
srv := server.NewServer(cfg, log, db.Conn)
// listen for new messages
diff --git a/components/auth.html b/components/auth.html
new file mode 100644
index 0000000..78ad338
--- /dev/null
+++ b/components/auth.html
@@ -0,0 +1,21 @@
+{{define "auth"}}
+<div id="logindiv">
+ <form class="space-y-6" hx-post="/login" hx-target="#ancestor" hx-swap="outerHTML">
+ <div>
+ <label For="username" class="block text-sm font-medium leading-6 text-white-900">username</label>
+ <div class="mt-2">
+ <input id="username" name="username" autocomplete="username" class="rounded-md text-center text-black" required />
+ </div>
+ </div>
+ <div>
+ <label For="room_pass" class="block text-sm font-medium leading-6 text-white-900">password</label>
+ <div class="mt-2">
+ <input id="password" name="password" type="password" required class="rounded-md text-center text-black" />
+ </div>
+ </div>
+ <div>
+ <button type="submit" class="justify-center rounded-md bg-indigo-600 px-3 py-1.5 text-sm font-semibold leading-6 text-white shadow-sm hover:bg-indigo-500 focus-visible:outline focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-indigo-600">Sign in</button>
+ </div>
+ </form>
+</div>
+{{end}}
diff --git a/components/index.html b/components/index.html
index b3da4ea..3693a56 100644
--- a/components/index.html
+++ b/components/index.html
@@ -10,13 +10,18 @@
</head>
<body>
<div id="ancestor">
- <hr />
+ {{ if not . }}
+ <div>
+ {{ template "auth" }}
+ </div>
+ {{ else }}
<div>
{{ template "UserScore" . }}
</div>
<div>
{{ template "showformbtn" }}
</div>
+ {{ end }}
</div>
</body>
</html>
diff --git a/internal/database/repos/main.go b/internal/database/repos/main.go
index e6f50be..25eeec4 100644
--- a/internal/database/repos/main.go
+++ b/internal/database/repos/main.go
@@ -4,6 +4,11 @@ import (
"github.com/jmoiron/sqlx"
)
+type FullRepo interface {
+ ActionRepo
+ UserScoreRepo
+}
+
type Provider struct {
db *sqlx.DB
}
diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go
index 435f8ff..5ec1c80 100644
--- a/internal/handlers/auth.go
+++ b/internal/handlers/auth.go
@@ -3,7 +3,7 @@ package handlers
import (
"apjournal/internal/models"
"apjournal/pkg/utils"
- "fmt"
+ "encoding/json"
"html/template"
"net/http"
"strings"
@@ -32,21 +32,32 @@ func (h *Handlers) HandleLogin(w http.ResponseWriter, r *http.Request) {
return
}
cleanName := utils.RemoveSpacesFromStr(username)
- // allNames := h.s.CacheGetAllNames()
- allNames := []string{}
- if utils.StrInSlice(cleanName, allNames) {
- err := fmt.Errorf("name: %s already taken", cleanName)
- h.log.Error("already taken", "error", err)
- abortWithError(w, err.Error())
- return
- }
cookie, err := h.makeCookie(cleanName, r.RemoteAddr)
if err != nil {
h.log.Error("failed to login", "error", err)
abortWithError(w, err.Error())
+ return
}
http.SetCookie(w, cookie)
- http.Redirect(w, r, "/", 302)
+ // http.Redirect(w, r, "/", 302)
+ tmpl, err := template.ParseGlob("components/*.html")
+ if err != nil {
+ panic(err)
+ }
+ userScore, err := h.repo.DBUserScoreGet(cleanName)
+ if err != nil {
+ h.log.Warn("got db err", "err", err)
+ if err := h.repo.DBUserScoreCreate(&us); err != nil {
+ panic(err)
+ }
+ tmpl.ExecuteTemplate(w, "main", nil)
+ return
+ }
+ userScore.Actions, err = h.repo.DBActionList(cleanName)
+ if err != nil {
+ panic(err)
+ }
+ tmpl.ExecuteTemplate(w, "main", userScore)
}
func (h *Handlers) makeCookie(username string, remote string) (*http.Cookie, error) {
@@ -72,10 +83,32 @@ func (h *Handlers) makeCookie(username string, remote string) (*http.Cookie, err
"remote", remote, "session", session)
if strings.Contains(remote, "192.168.0") {
// no idea what is going on
- // domainName = "192.168.0.101"
cookie.Domain = "192.168.0.101"
}
// set ctx?
// c.Set("username", username)
return cookie, nil
}
+
+func (h *Handlers) cacheGetSession(key string) (*models.Session, error) {
+ userSessionB, err := h.mc.Get(key)
+ if err != nil {
+ return nil, err
+ }
+ var us *models.Session
+ if err := json.Unmarshal(userSessionB, &us); err != nil {
+ return nil, err
+ }
+ return us, nil
+}
+
+func (h *Handlers) cacheSetSession(key string, session *models.Session) error {
+ sesb, err := json.Marshal(session)
+ if err != nil {
+ return err
+ }
+ h.mc.Set(key, sesb)
+ // expire in 10 min
+ h.mc.Expire(key, 10*60)
+ return nil
+}
diff --git a/internal/handlers/main.go b/internal/handlers/main.go
index d779f42..aa9db4f 100644
--- a/internal/handlers/main.go
+++ b/internal/handlers/main.go
@@ -4,6 +4,7 @@ import (
"apjournal/config"
"apjournal/internal/database/repos"
"apjournal/internal/models"
+ "apjournal/pkg/cache"
"html/template"
"log/slog"
"net/http"
@@ -18,12 +19,12 @@ import (
type Handlers struct {
cfg config.Config
log *slog.Logger
- repo *repos.Provider
+ repo repos.FullRepo
+ mc cache.Cache
}
// NewHandlers constructor
func NewHandlers(
- // cfg config.Config, s *service.Service, l *slog.Logger,
cfg config.Config, l *slog.Logger, conn *sqlx.DB,
) *Handlers {
if l == nil {
@@ -33,6 +34,7 @@ func NewHandlers(
cfg: cfg,
log: l,
repo: repos.NewProvider(conn),
+ mc: cache.MemCache,
}
return h
}
@@ -50,13 +52,22 @@ func (h *Handlers) Ping(w http.ResponseWriter, r *http.Request) {
}
func (h *Handlers) MainPage(w http.ResponseWriter, r *http.Request) {
- h.log.Info("got mainpage request")
- // tmpl := template.Must(template.ParseFiles("components/index.html"))
tmpl, err := template.ParseGlob("components/*.html")
if err != nil {
panic(err)
}
- userScore, err := h.repo.DBUserScoreGet("test")
+ usernameRaw := r.Context().Value("username")
+ h.log.Info("got mainpage request", "username", usernameRaw)
+ if usernameRaw == nil {
+ tmpl.ExecuteTemplate(w, "main", nil)
+ return
+ }
+ username := usernameRaw.(string)
+ if username == "" {
+ tmpl.ExecuteTemplate(w, "main", nil)
+ return
+ }
+ userScore, err := h.repo.DBUserScoreGet(username)
if err != nil {
h.log.Warn("got db err", "err", err)
if err := h.repo.DBUserScoreCreate(&us); err != nil {
@@ -65,13 +76,10 @@ func (h *Handlers) MainPage(w http.ResponseWriter, r *http.Request) {
tmpl.ExecuteTemplate(w, "main", us)
return
}
- userScore.Actions, err = h.repo.DBActionList("test")
+ userScore.Actions, err = h.repo.DBActionList(username)
if err != nil {
panic(err)
}
- // tmpl.Execute(w, us)
- // us.Username = "test"
- // us.BurnTime = time.Now().Add(time.Duration(24) * time.Hour)
tmpl.ExecuteTemplate(w, "main", userScore)
}
diff --git a/internal/handlers/middleware.go b/internal/handlers/middleware.go
new file mode 100644
index 0000000..28ccdbc
--- /dev/null
+++ b/internal/handlers/middleware.go
@@ -0,0 +1,48 @@
+package handlers
+
+import (
+ "context"
+ "errors"
+ "net/http"
+)
+
+func (h *Handlers) GetSession(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ sessionCookie, err := r.Cookie("session_token")
+ if err != nil {
+ msg := "auth failed; failed to get session token from cookies"
+ h.log.Debug(msg, "error", err)
+ next.ServeHTTP(w, r)
+ return
+ }
+ sessionToken := ""
+ if sessionCookie.Value == "" {
+ sessionToken = sessionCookie.Value
+ }
+ userSession, err := h.cacheGetSession(sessionCookie.Value)
+ if err != nil {
+ msg := "auth failed; session does not exists"
+ err = errors.New(msg)
+ h.log.Debug(msg, "error", err)
+ next.ServeHTTP(w, r)
+ return
+ }
+ if userSession.IsExpired() {
+ h.mc.RemoveKey(sessionToken)
+ msg := "session is expired"
+ h.log.Debug(msg, "error", err, "token", sessionToken)
+ next.ServeHTTP(w, r)
+ return
+ }
+ ctx := context.WithValue(r.Context(),
+ "username", userSession.Username)
+ if err := h.cacheSetSession(sessionToken,
+ userSession); err != nil {
+ msg := "failed to marshal user session"
+ h.log.Warn(msg, "error", err)
+ next.ServeHTTP(w, r)
+ return
+ }
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+}
diff --git a/internal/server/router.go b/internal/server/router.go
index 36c2083..3e8785f 100644
--- a/internal/server/router.go
+++ b/internal/server/router.go
@@ -10,8 +10,8 @@ func (srv *server) ListenToRequests() {
h := srv.actions
mux := http.NewServeMux()
server := &http.Server{
- Addr: fmt.Sprintf("localhost:%d", srv.config.ServerConfig),
- Handler: mux,
+ Addr: fmt.Sprintf("localhost:%s", srv.config.ServerConfig.Port),
+ Handler: h.GetSession(mux),
ReadTimeout: time.Second * 5,
WriteTimeout: time.Second * 5,
}
diff --git a/pkg/cache/interface.go b/pkg/cache/interface.go
new file mode 100644
index 0000000..606f50f
--- /dev/null
+++ b/pkg/cache/interface.go
@@ -0,0 +1,9 @@
+package cache
+
+type Cache interface {
+ Get(key string) ([]byte, error)
+ Set(key string, value []byte)
+ Expire(key string, exp int64)
+ GetAll() (resp map[string][]byte)
+ RemoveKey(key string)
+}
diff --git a/pkg/cache/main.go b/pkg/cache/main.go
new file mode 100644
index 0000000..d617f49
--- /dev/null
+++ b/pkg/cache/main.go
@@ -0,0 +1,146 @@
+package cache
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "log/slog"
+ "os"
+ "sync"
+ "time"
+)
+
+const storeFileName = "store.json"
+
+// var MemCache Cache
+var (
+ MemCache *MemoryCache
+ log = slog.New(slog.NewJSONHandler(os.Stdout, nil))
+)
+
+func readJSON(fileName string) (map[string][]byte, error) {
+ data := make(map[string][]byte)
+ file, err := os.Open(fileName)
+ if err != nil {
+ return data, err
+ }
+ defer file.Close()
+ decoder := json.NewDecoder(file)
+ if err := decoder.Decode(&data); err != nil {
+ return data, err
+ }
+ return data, nil
+}
+
+func init() {
+ data, err := readJSON(storeFileName)
+ if err != nil {
+ log.Warn("failed to load store from file", "error", err)
+ }
+ MemCache = &MemoryCache{
+ data: data,
+ timeMap: make(map[string]time.Time),
+ lock: &sync.RWMutex{},
+ }
+ MemCache.StartExpiryRoutine(time.Minute)
+ MemCache.StartBackupRoutine(time.Minute)
+}
+
+type MemoryCache struct {
+ data map[string][]byte
+ timeMap map[string]time.Time
+ lock *sync.RWMutex
+}
+
+// Get a value by key from the cache
+func (mc *MemoryCache) Get(key string) (value []byte, err error) {
+ var ok bool
+ mc.lock.RLock()
+ if value, ok = mc.data[key]; !ok {
+ err = fmt.Errorf("not found data in mc for the key: %v", key)
+ }
+ mc.lock.RUnlock()
+ return value, err
+}
+
+// Update a single value in the cache
+func (mc *MemoryCache) Set(key string, value []byte) {
+ // no async writing
+ mc.lock.Lock()
+ mc.data[key] = value
+ mc.lock.Unlock()
+}
+
+func (mc *MemoryCache) Expire(key string, exp int64) {
+ mc.lock.RLock()
+ mc.timeMap[key] = time.Now().Add(time.Duration(exp) * time.Second)
+ mc.lock.RUnlock()
+}
+
+func (mc *MemoryCache) GetAll() (resp map[string][]byte) {
+ resp = make(map[string][]byte)
+ mc.lock.RLock()
+ for k, v := range mc.data {
+ resp[k] = v
+ }
+ mc.lock.RUnlock()
+ return
+}
+
+func (mc *MemoryCache) GetAllTime() (resp map[string]time.Time) {
+ resp = make(map[string]time.Time)
+ mc.lock.RLock()
+ for k, v := range mc.timeMap {
+ resp[k] = v
+ }
+ mc.lock.RUnlock()
+ return
+}
+
+func (mc *MemoryCache) RemoveKey(key string) {
+ mc.lock.RLock()
+ delete(mc.data, key)
+ delete(mc.timeMap, key)
+ mc.lock.RUnlock()
+}
+
+func (mc *MemoryCache) StartExpiryRoutine(n time.Duration) {
+ ticker := time.NewTicker(n)
+ go func() {
+ for {
+ <-ticker.C
+ // get all
+ timeData := mc.GetAllTime()
+ // check time
+ currentTS := time.Now()
+ for k, ts := range timeData {
+ if ts.Before(currentTS) {
+ // delete exp keys
+ mc.RemoveKey(k)
+ log.Info("remove by expiry", "key", k)
+ }
+ }
+ }
+ }()
+}
+
+func (mc *MemoryCache) StartBackupRoutine(n time.Duration) {
+ ticker := time.NewTicker(n)
+ go func() {
+ for {
+ <-ticker.C
+ // get all
+ data := mc.GetAll()
+ jsonString, err := json.Marshal(data)
+ if err != nil {
+ log.Warn("failed to marshal kv store", "error", err)
+ continue
+ }
+ err = ioutil.WriteFile(storeFileName, jsonString, os.ModePerm)
+ if err != nil {
+ log.Warn("failed to write to json file", "error", err)
+ continue
+ }
+ }
+ }()
+}