diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | cmd/start.go | 3 | ||||
-rw-r--r-- | components/auth.html | 21 | ||||
-rw-r--r-- | components/index.html | 7 | ||||
-rw-r--r-- | internal/database/repos/main.go | 5 | ||||
-rw-r--r-- | internal/handlers/auth.go | 55 | ||||
-rw-r--r-- | internal/handlers/main.go | 26 | ||||
-rw-r--r-- | internal/handlers/middleware.go | 48 | ||||
-rw-r--r-- | internal/server/router.go | 4 | ||||
-rw-r--r-- | pkg/cache/interface.go | 9 | ||||
-rw-r--r-- | pkg/cache/main.go | 146 |
11 files changed, 301 insertions, 24 deletions
@@ -2,3 +2,4 @@ secrethitler config.yml apjournal tags +store.json diff --git a/cmd/start.go b/cmd/start.go index a1f051c..5d2f5b9 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -31,7 +31,8 @@ var startCmd = &cobra.Command{ cfg := config.LoadConfig(viper.GetViper()) db, err := database.InitWithMigrate(cfg.DBURI, true) if err != nil { - panic(err) + log.Error("failed to connect to db", "error", err) + return } srv := server.NewServer(cfg, log, db.Conn) // listen for new messages diff --git a/components/auth.html b/components/auth.html new file mode 100644 index 0000000..78ad338 --- /dev/null +++ b/components/auth.html @@ -0,0 +1,21 @@ +{{define "auth"}} +<div id="logindiv"> + <form class="space-y-6" hx-post="/login" hx-target="#ancestor" hx-swap="outerHTML"> + <div> + <label For="username" class="block text-sm font-medium leading-6 text-white-900">username</label> + <div class="mt-2"> + <input id="username" name="username" autocomplete="username" class="rounded-md text-center text-black" required /> + </div> + </div> + <div> + <label For="room_pass" class="block text-sm font-medium leading-6 text-white-900">password</label> + <div class="mt-2"> + <input id="password" name="password" type="password" required class="rounded-md text-center text-black" /> + </div> + </div> + <div> + <button type="submit" class="justify-center rounded-md bg-indigo-600 px-3 py-1.5 text-sm font-semibold leading-6 text-white shadow-sm hover:bg-indigo-500 focus-visible:outline focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-indigo-600">Sign in</button> + </div> + </form> +</div> +{{end}} diff --git a/components/index.html b/components/index.html index b3da4ea..3693a56 100644 --- a/components/index.html +++ b/components/index.html @@ -10,13 +10,18 @@ </head> <body> <div id="ancestor"> - <hr /> + {{ if not . }} + <div> + {{ template "auth" }} + </div> + {{ else }} <div> {{ template "UserScore" . }} </div> <div> {{ template "showformbtn" }} </div> + {{ end }} </div> </body> </html> diff --git a/internal/database/repos/main.go b/internal/database/repos/main.go index e6f50be..25eeec4 100644 --- a/internal/database/repos/main.go +++ b/internal/database/repos/main.go @@ -4,6 +4,11 @@ import ( "github.com/jmoiron/sqlx" ) +type FullRepo interface { + ActionRepo + UserScoreRepo +} + type Provider struct { db *sqlx.DB } diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index 435f8ff..5ec1c80 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -3,7 +3,7 @@ package handlers import ( "apjournal/internal/models" "apjournal/pkg/utils" - "fmt" + "encoding/json" "html/template" "net/http" "strings" @@ -32,21 +32,32 @@ func (h *Handlers) HandleLogin(w http.ResponseWriter, r *http.Request) { return } cleanName := utils.RemoveSpacesFromStr(username) - // allNames := h.s.CacheGetAllNames() - allNames := []string{} - if utils.StrInSlice(cleanName, allNames) { - err := fmt.Errorf("name: %s already taken", cleanName) - h.log.Error("already taken", "error", err) - abortWithError(w, err.Error()) - return - } cookie, err := h.makeCookie(cleanName, r.RemoteAddr) if err != nil { h.log.Error("failed to login", "error", err) abortWithError(w, err.Error()) + return } http.SetCookie(w, cookie) - http.Redirect(w, r, "/", 302) + // http.Redirect(w, r, "/", 302) + tmpl, err := template.ParseGlob("components/*.html") + if err != nil { + panic(err) + } + userScore, err := h.repo.DBUserScoreGet(cleanName) + if err != nil { + h.log.Warn("got db err", "err", err) + if err := h.repo.DBUserScoreCreate(&us); err != nil { + panic(err) + } + tmpl.ExecuteTemplate(w, "main", nil) + return + } + userScore.Actions, err = h.repo.DBActionList(cleanName) + if err != nil { + panic(err) + } + tmpl.ExecuteTemplate(w, "main", userScore) } func (h *Handlers) makeCookie(username string, remote string) (*http.Cookie, error) { @@ -72,10 +83,32 @@ func (h *Handlers) makeCookie(username string, remote string) (*http.Cookie, err "remote", remote, "session", session) if strings.Contains(remote, "192.168.0") { // no idea what is going on - // domainName = "192.168.0.101" cookie.Domain = "192.168.0.101" } // set ctx? // c.Set("username", username) return cookie, nil } + +func (h *Handlers) cacheGetSession(key string) (*models.Session, error) { + userSessionB, err := h.mc.Get(key) + if err != nil { + return nil, err + } + var us *models.Session + if err := json.Unmarshal(userSessionB, &us); err != nil { + return nil, err + } + return us, nil +} + +func (h *Handlers) cacheSetSession(key string, session *models.Session) error { + sesb, err := json.Marshal(session) + if err != nil { + return err + } + h.mc.Set(key, sesb) + // expire in 10 min + h.mc.Expire(key, 10*60) + return nil +} diff --git a/internal/handlers/main.go b/internal/handlers/main.go index d779f42..aa9db4f 100644 --- a/internal/handlers/main.go +++ b/internal/handlers/main.go @@ -4,6 +4,7 @@ import ( "apjournal/config" "apjournal/internal/database/repos" "apjournal/internal/models" + "apjournal/pkg/cache" "html/template" "log/slog" "net/http" @@ -18,12 +19,12 @@ import ( type Handlers struct { cfg config.Config log *slog.Logger - repo *repos.Provider + repo repos.FullRepo + mc cache.Cache } // NewHandlers constructor func NewHandlers( - // cfg config.Config, s *service.Service, l *slog.Logger, cfg config.Config, l *slog.Logger, conn *sqlx.DB, ) *Handlers { if l == nil { @@ -33,6 +34,7 @@ func NewHandlers( cfg: cfg, log: l, repo: repos.NewProvider(conn), + mc: cache.MemCache, } return h } @@ -50,13 +52,22 @@ func (h *Handlers) Ping(w http.ResponseWriter, r *http.Request) { } func (h *Handlers) MainPage(w http.ResponseWriter, r *http.Request) { - h.log.Info("got mainpage request") - // tmpl := template.Must(template.ParseFiles("components/index.html")) tmpl, err := template.ParseGlob("components/*.html") if err != nil { panic(err) } - userScore, err := h.repo.DBUserScoreGet("test") + usernameRaw := r.Context().Value("username") + h.log.Info("got mainpage request", "username", usernameRaw) + if usernameRaw == nil { + tmpl.ExecuteTemplate(w, "main", nil) + return + } + username := usernameRaw.(string) + if username == "" { + tmpl.ExecuteTemplate(w, "main", nil) + return + } + userScore, err := h.repo.DBUserScoreGet(username) if err != nil { h.log.Warn("got db err", "err", err) if err := h.repo.DBUserScoreCreate(&us); err != nil { @@ -65,13 +76,10 @@ func (h *Handlers) MainPage(w http.ResponseWriter, r *http.Request) { tmpl.ExecuteTemplate(w, "main", us) return } - userScore.Actions, err = h.repo.DBActionList("test") + userScore.Actions, err = h.repo.DBActionList(username) if err != nil { panic(err) } - // tmpl.Execute(w, us) - // us.Username = "test" - // us.BurnTime = time.Now().Add(time.Duration(24) * time.Hour) tmpl.ExecuteTemplate(w, "main", userScore) } diff --git a/internal/handlers/middleware.go b/internal/handlers/middleware.go new file mode 100644 index 0000000..28ccdbc --- /dev/null +++ b/internal/handlers/middleware.go @@ -0,0 +1,48 @@ +package handlers + +import ( + "context" + "errors" + "net/http" +) + +func (h *Handlers) GetSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sessionCookie, err := r.Cookie("session_token") + if err != nil { + msg := "auth failed; failed to get session token from cookies" + h.log.Debug(msg, "error", err) + next.ServeHTTP(w, r) + return + } + sessionToken := "" + if sessionCookie.Value == "" { + sessionToken = sessionCookie.Value + } + userSession, err := h.cacheGetSession(sessionCookie.Value) + if err != nil { + msg := "auth failed; session does not exists" + err = errors.New(msg) + h.log.Debug(msg, "error", err) + next.ServeHTTP(w, r) + return + } + if userSession.IsExpired() { + h.mc.RemoveKey(sessionToken) + msg := "session is expired" + h.log.Debug(msg, "error", err, "token", sessionToken) + next.ServeHTTP(w, r) + return + } + ctx := context.WithValue(r.Context(), + "username", userSession.Username) + if err := h.cacheSetSession(sessionToken, + userSession); err != nil { + msg := "failed to marshal user session" + h.log.Warn(msg, "error", err) + next.ServeHTTP(w, r) + return + } + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/internal/server/router.go b/internal/server/router.go index 36c2083..3e8785f 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -10,8 +10,8 @@ func (srv *server) ListenToRequests() { h := srv.actions mux := http.NewServeMux() server := &http.Server{ - Addr: fmt.Sprintf("localhost:%d", srv.config.ServerConfig), - Handler: mux, + Addr: fmt.Sprintf("localhost:%s", srv.config.ServerConfig.Port), + Handler: h.GetSession(mux), ReadTimeout: time.Second * 5, WriteTimeout: time.Second * 5, } diff --git a/pkg/cache/interface.go b/pkg/cache/interface.go new file mode 100644 index 0000000..606f50f --- /dev/null +++ b/pkg/cache/interface.go @@ -0,0 +1,9 @@ +package cache + +type Cache interface { + Get(key string) ([]byte, error) + Set(key string, value []byte) + Expire(key string, exp int64) + GetAll() (resp map[string][]byte) + RemoveKey(key string) +} diff --git a/pkg/cache/main.go b/pkg/cache/main.go new file mode 100644 index 0000000..d617f49 --- /dev/null +++ b/pkg/cache/main.go @@ -0,0 +1,146 @@ +package cache + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log/slog" + "os" + "sync" + "time" +) + +const storeFileName = "store.json" + +// var MemCache Cache +var ( + MemCache *MemoryCache + log = slog.New(slog.NewJSONHandler(os.Stdout, nil)) +) + +func readJSON(fileName string) (map[string][]byte, error) { + data := make(map[string][]byte) + file, err := os.Open(fileName) + if err != nil { + return data, err + } + defer file.Close() + decoder := json.NewDecoder(file) + if err := decoder.Decode(&data); err != nil { + return data, err + } + return data, nil +} + +func init() { + data, err := readJSON(storeFileName) + if err != nil { + log.Warn("failed to load store from file", "error", err) + } + MemCache = &MemoryCache{ + data: data, + timeMap: make(map[string]time.Time), + lock: &sync.RWMutex{}, + } + MemCache.StartExpiryRoutine(time.Minute) + MemCache.StartBackupRoutine(time.Minute) +} + +type MemoryCache struct { + data map[string][]byte + timeMap map[string]time.Time + lock *sync.RWMutex +} + +// Get a value by key from the cache +func (mc *MemoryCache) Get(key string) (value []byte, err error) { + var ok bool + mc.lock.RLock() + if value, ok = mc.data[key]; !ok { + err = fmt.Errorf("not found data in mc for the key: %v", key) + } + mc.lock.RUnlock() + return value, err +} + +// Update a single value in the cache +func (mc *MemoryCache) Set(key string, value []byte) { + // no async writing + mc.lock.Lock() + mc.data[key] = value + mc.lock.Unlock() +} + +func (mc *MemoryCache) Expire(key string, exp int64) { + mc.lock.RLock() + mc.timeMap[key] = time.Now().Add(time.Duration(exp) * time.Second) + mc.lock.RUnlock() +} + +func (mc *MemoryCache) GetAll() (resp map[string][]byte) { + resp = make(map[string][]byte) + mc.lock.RLock() + for k, v := range mc.data { + resp[k] = v + } + mc.lock.RUnlock() + return +} + +func (mc *MemoryCache) GetAllTime() (resp map[string]time.Time) { + resp = make(map[string]time.Time) + mc.lock.RLock() + for k, v := range mc.timeMap { + resp[k] = v + } + mc.lock.RUnlock() + return +} + +func (mc *MemoryCache) RemoveKey(key string) { + mc.lock.RLock() + delete(mc.data, key) + delete(mc.timeMap, key) + mc.lock.RUnlock() +} + +func (mc *MemoryCache) StartExpiryRoutine(n time.Duration) { + ticker := time.NewTicker(n) + go func() { + for { + <-ticker.C + // get all + timeData := mc.GetAllTime() + // check time + currentTS := time.Now() + for k, ts := range timeData { + if ts.Before(currentTS) { + // delete exp keys + mc.RemoveKey(k) + log.Info("remove by expiry", "key", k) + } + } + } + }() +} + +func (mc *MemoryCache) StartBackupRoutine(n time.Duration) { + ticker := time.NewTicker(n) + go func() { + for { + <-ticker.C + // get all + data := mc.GetAll() + jsonString, err := json.Marshal(data) + if err != nil { + log.Warn("failed to marshal kv store", "error", err) + continue + } + err = ioutil.WriteFile(storeFileName, jsonString, os.ModePerm) + if err != nil { + log.Warn("failed to write to json file", "error", err) + continue + } + } + }() +} |