package main

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/binary"
	"fmt"
	"hash/crc32"
	"time"

	"github.com/gin-gonic/gin"
	log "github.com/sirupsen/logrus"
)

// sessionMap maps a session identifier (the raw bytes of
// SessionToken.Identifier converted to a string) to a session ID.
//
// NOTE(review): this map is never initialized or written in this file and is
// read without any locking; confirm where it is populated and whether it is
// accessed from concurrent request handlers.
var sessionMap map[string]uint64

// SessionToken is a two-part session token. Verifier holds 32 random bytes and
// Identifier is derived from it by deriveSessionIdentifier; Identifier is the
// key looked up in sessionMap.
type SessionToken struct {
	Identifier [32]byte
	Verifier   [32]byte
}

// CSRFToken is a fixed-size (32-byte, i.e. two AES blocks) anti-CSRF token
// recording the request path it was issued for and its issue time.
type CSRFToken struct {
	Seed     [4]byte  // random bytes filled in by NewCSRFToken
	Path     [16]byte // request path, NUL-padded; longer paths are truncated
	Time     int64    // issue time in Unix seconds
	Checksum uint32   // CRC-32 (IEEE) of the token serialized with Checksum = 0
}

// deriveSessionIdentifier returns HMAC-SHA256 of verifier, keyed with verifier
// itself. This is the value stored in SessionToken.Identifier.
func deriveSessionIdentifier(verifier []byte) []byte {
	mac := hmac.New(sha256.New, verifier)
	mac.Write(verifier) // hash.Hash.Write never returns an error
	return mac.Sum(nil)
}

// NewSessionToken creates a session token with a fresh random Verifier and the
// Identifier derived from it.
//
// userID is currently unused. NOTE(review): confirm whether callers are
// expected to register the returned token in sessionMap themselves.
//
// If the system random source fails, the error is logged and an all-zero token
// is returned. Its Identifier does not match its Verifier, so GetSessionID
// rejects it.
func NewSessionToken(userID uint64) *SessionToken {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	t := SessionToken{}
	if _, err := rand.Read(t.Verifier[:]); err != nil {
		log.WithFields(log.Fields{"call": "rand.Read", "attr": "verifier", "err": err}).Errorf("")
		return &SessionToken{}
	}
	copy(t.Identifier[:], deriveSessionIdentifier(t.Verifier[:]))
	return &t
}

// Bytes serializes the token as Identifier followed by Verifier (64 bytes).
func (t *SessionToken) Bytes() []byte {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	buf := new(bytes.Buffer)
	// binary.Write cannot fail for a fixed-size struct written to a bytes.Buffer.
	_ = binary.Write(buf, binary.LittleEndian, t)
	return buf.Bytes()
}

// Load decodes b, as produced by Bytes, into t. It returns an error if b does
// not have exactly the serialized size of a SessionToken.
func (t *SessionToken) Load(b []byte) error {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	if len(b) != binary.Size(t) {
		return fmt.Errorf("wrong size")
	}
	return binary.Read(bytes.NewReader(b), binary.LittleEndian, t)
}

// GetSessionID returns the session ID registered for t. It returns 0 if t's
// Identifier is not the one derived from its Verifier, or if no session is
// registered under that identifier.
func (t *SessionToken) GetSessionID() uint64 {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	// Recompute the identifier from the verifier instead of trusting the
	// client-supplied Identifier. Knowing the lookup key alone is then not
	// enough to use a session.
	id := deriveSessionIdentifier(t.Verifier[:])
	if !hmac.Equal(id, t.Identifier[:]) {
		return 0
	}
	// A missing key yields the zero value, i.e. 0.
	return sessionMap[string(id)]
}

// GetSessionTokenParam decodes the session token carried in the "key" path
// parameter of the request.
//
// The request form is parsed first and a parse error is returned.
// NOTE(review): the parsed form is not used here; confirm whether the parse is
// still needed.
func GetSessionTokenParam(c *gin.Context) (*SessionToken, error) {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	if err := c.Request.ParseForm(); err != nil {
		log.WithFields(log.Fields{"call": "Context.Request.ParseForm", "err": err}).Errorf("")
		return nil, err
	}
	return DecodeSessionToken(c.Param("key"))
}

// GetSessionTokenCookie decodes the session token stored in the "session"
// cookie. It returns the cookie lookup error if the cookie is absent.
func GetSessionTokenCookie(c *gin.Context) (*SessionToken, error) {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	session, err := c.Cookie("session")
	if err != nil {
		log.WithFields(log.Fields{"call": "Context.Cookie", "attr": "session", "err": err}).Errorf("")
		return nil, err
	}
	return DecodeSessionToken(session)
}

// Encode returns the standard (padded) base64 encoding of t.Bytes().
func (t *SessionToken) Encode() string {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	return base64.StdEncoding.EncodeToString(t.Bytes())
}

// DecodeSessionToken is the inverse of SessionToken.Encode. It returns an
// error on invalid base64 or a payload of the wrong size.
func DecodeSessionToken(s string) (*SessionToken, error) {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	b, err := base64.StdEncoding.DecodeString(s)
	if err != nil {
		log.WithFields(log.Fields{"call": "base64.StdEncoding.DecodeString", "attr": "session", "err": err}).Errorf("")
		return nil, err
	}
	t := SessionToken{}
	if err := t.Load(b); err != nil {
		return nil, err
	}
	return &t, nil
}

// GetPath returns the stored path up to (not including) the first NUL byte.
func (t *CSRFToken) GetPath() string {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	p := t.Path[:]
	if i := bytes.IndexByte(p, 0); i >= 0 {
		p = p[:i]
	}
	return string(p)
}

// Bytes serializes t little-endian in field order (32 bytes).
func (t *CSRFToken) Bytes() []byte {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	buf := new(bytes.Buffer)
	// binary.Write cannot fail for a fixed-size struct written to a bytes.Buffer.
	_ = binary.Write(buf, binary.LittleEndian, t)
	return buf.Bytes()
}

// Load decodes the plaintext b into t and verifies its CRC-32 checksum.
// It returns an error on a size mismatch, decode failure, or bad checksum.
// On error, t is left unmodified.
func (t *CSRFToken) Load(b []byte) error {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	if len(b) != binary.Size(t) {
		return fmt.Errorf("wrong size")
	}
	// Decode into a temporary so a failed validation never leaves the
	// receiver partially overwritten.
	var tmp CSRFToken
	if err := binary.Read(bytes.NewReader(b), binary.LittleEndian, &tmp); err != nil {
		log.WithFields(log.Fields{"call": "binary.Read", "attr": "b", "err": err}).Errorf("")
		return err
	}
	// The checksum was computed with the Checksum field zeroed.
	want := tmp.Checksum
	tmp.Checksum = 0
	if crc32.ChecksumIEEE(tmp.Bytes()) != want {
		return fmt.Errorf("wrong checksum")
	}
	tmp.Checksum = want
	*t = tmp
	return nil
}

// NewCSRFToken builds a token for the current request path (truncated to 16
// bytes), stamped with the current time and a random seed.
//
// A failure of the random source is logged but does not abort token creation.
func NewCSRFToken(c *gin.Context) *CSRFToken {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	t := CSRFToken{}
	copy(t.Path[:], c.Request.URL.Path)
	t.Time = time.Now().UTC().Unix()
	if _, err := rand.Read(t.Seed[:]); err != nil {
		log.WithFields(log.Fields{"call": "rand.Read", "attr": "seed", "err": err}).Errorf("")
	}
	t.Checksum = crc32.ChecksumIEEE(t.Bytes())
	return &t
}

// GetCSRFToken decrypts and validates the token held in the "context" cookie.
// When the cookie is missing, the returned error wraps the cookie lookup error.
func GetCSRFToken(c *gin.Context) (*CSRFToken, error) {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	context, err := c.Cookie("context")
	if err != nil {
		return nil, fmt.Errorf("no context param: %w", err)
	}
	return DecodeCSRFToken(context)
}

// csrfCipher returns the AES-256 block cipher used for CSRF tokens. The key is
// the SHA-256 digest of cfg.Admin.Secrets.ContextKey.
func csrfCipher() (cipher.Block, error) {
	key := sha256.Sum256([]byte(cfg.Admin.Secrets.ContextKey))
	return aes.NewCipher(key[:])
}

// Encode encrypts t with AES-256-CBC and returns the ciphertext in standard
// base64.
//
// NOTE(review): the IV is fixed at all zeros, and integrity relies only on the
// CRC-32 checksum, which is not a MAC. Consider an AEAD such as AES-GCM if the
// token format can change.
func (t *CSRFToken) Encode() string {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	block, err := csrfCipher()
	if err != nil {
		log.WithFields(log.Fields{"call": "aes.NewCipher", "err": err}).Fatalf("")
	}
	plaintext := t.Bytes()
	// CBC works on whole blocks; CSRFToken is laid out to be exactly 32 bytes.
	if len(plaintext) != 32 {
		log.WithFields(log.Fields{"err": fmt.Errorf("size is wrong")}).Fatalf("")
	}
	ciphertext := make([]byte, len(plaintext))
	iv := make([]byte, aes.BlockSize)
	cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
	return base64.StdEncoding.EncodeToString(ciphertext)
}

// DecodeCSRFToken is the inverse of CSRFToken.Encode. It returns an error on
// invalid base64, a wrong ciphertext size, or a plaintext that fails
// CSRFToken.Load.
func DecodeCSRFToken(s string) (*CSRFToken, error) {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	b, err := base64.StdEncoding.DecodeString(s)
	if err != nil {
		return nil, err
	}
	// Must be checked before decryption: CryptBlocks panics on partial blocks.
	if len(b) != binary.Size(CSRFToken{}) {
		return nil, fmt.Errorf("wrong size")
	}
	block, err := csrfCipher()
	if err != nil {
		return nil, err
	}
	iv := make([]byte, aes.BlockSize)
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(b, b) // decrypt in place

	t := CSRFToken{}
	if err := t.Load(b); err != nil {
		return nil, err
	}
	return &t, nil
}

// Valid reports whether no more than cfg.Admin.Secrets.ContextExpiration
// seconds have elapsed since the token was issued.
func (t *CSRFToken) Valid() bool {
	log.WithFields(log.Fields{}).Debugf("starting")
	defer log.WithFields(log.Fields{}).Debugf("done")

	maxAge := time.Duration(cfg.Admin.Secrets.ContextExpiration) * time.Second
	return time.Since(time.Unix(t.Time, 0)) <= maxAge
}

// SetCSRFToken issues a new CSRF token for the current request and stores it,
// encrypted, in the HttpOnly "context" cookie.
//
// NOTE(review): the cookie is set with secure=false; confirm this is intended
// for the deployment (e.g. no HTTPS termination in front of the app).
func SetCSRFToken(c *gin.Context) {
	c.SetCookie("context", NewCSRFToken(c).Encode(), cfg.Admin.Secrets.ContextExpiration, "/", cfg.Admin.URL, false, true)
}