diff --git a/db.sql b/db.sql deleted file mode 100644 index b155d59..0000000 --- a/db.sql +++ /dev/null @@ -1,22 +0,0 @@ -CREATE TABLE `heartbeat` ( - `id` int(11) NOT NULL AUTO_INCREMENT, - `user` varchar(255) NOT NULL, - `time` datetime NOT NULL, - `entity` varchar(1024) DEFAULT NULL, - `type` varchar(255) NOT NULL, - `category` varchar(255) DEFAULT NULL, - `is_write` tinyint(4) NOT NULL, - `branch` varchar(255) DEFAULT NULL, - `language` varchar(255) DEFAULT NULL, - `project` varchar(255) DEFAULT NULL, - `operating_system` varchar(45) DEFAULT NULL, - `editor` varchar(45) DEFAULT NULL, - PRIMARY KEY (`id`) -) ENGINE=InnoDB DEFAULT CHARSET=latin1; - -CREATE TABLE `user` ( - `user_id` varchar(255) NOT NULL, - `api_key` varchar(255) NOT NULL, - PRIMARY KEY (`user_id`), - KEY `IDX_API_KEY` (`api_key`) -) ENGINE=InnoDB DEFAULT CHARSET=latin1; \ No newline at end of file diff --git a/main.go b/main.go index 648decf..b98f720 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "database/sql" "log" "net/http" "os" @@ -10,13 +9,16 @@ import ( "github.com/codegangsta/negroni" "github.com/gorilla/mux" + "github.com/jinzhu/gorm" "github.com/joho/godotenv" - "github.com/go-sql-driver/mysql" "github.com/n1try/wakapi/middlewares" "github.com/n1try/wakapi/models" "github.com/n1try/wakapi/routes" "github.com/n1try/wakapi/services" + "github.com/n1try/wakapi/utils" + + _ "github.com/jinzhu/gorm/dialects/mysql" ) func readConfig() models.Config { @@ -44,6 +46,7 @@ func readConfig() models.Config { DbUser: dbUser, DbPassword: dbPassword, DbName: dbName, + DbDialect: "mysql", } } @@ -51,22 +54,17 @@ func main() { // Read Config config := readConfig() - // Connect Database - dbConfig := mysql.Config{ - User: config.DbUser, - Passwd: config.DbPassword, - Net: "tcp", - Addr: config.DbHost, - DBName: config.DbName, - AllowNativePasswords: true, - ParseTime: true, - } - db, _ := sql.Open("mysql", dbConfig.FormatDSN()) - defer db.Close() - err := db.Ping() + // Connect to database + db, err := gorm.Open(config.DbDialect, utils.MakeConnectionString(&config)) if err != nil { - log.Fatal("Could not connect to database.") + // log.Fatal("Could not connect to database.") + log.Fatal(err) } + defer db.Close() + + // Migrate database schema + db.AutoMigrate(&models.User{}) + db.AutoMigrate(&models.Heartbeat{}).AddForeignKey("user_id", "users(id)", "RESTRICT", "RESTRICT") // Services heartbeatSrvc := &services.HeartbeatService{db} diff --git a/middlewares/authenticate.go b/middlewares/authenticate.go index f7c1cd2..26f65ba 100644 --- a/middlewares/authenticate.go +++ b/middlewares/authenticate.go @@ -33,6 +33,6 @@ func (m *AuthenticateMiddleware) Handle(w http.ResponseWriter, r *http.Request, return } - ctx := context.WithValue(r.Context(), models.UserKey, &user) + ctx := context.WithValue(r.Context(), models.UserKey, user) next(w, r.WithContext(ctx)) } diff --git a/models/config.go b/models/config.go index 615953b..436814a 100644 --- a/models/config.go +++ b/models/config.go @@ -6,4 +6,5 @@ type Config struct { DbUser string DbPassword string DbName string + DbDialect string } diff --git a/models/heartbeat.go b/models/heartbeat.go index 45d8628..086637b 100644 --- a/models/heartbeat.go +++ b/models/heartbeat.go @@ -1,25 +1,36 @@ package models import ( + "database/sql/driver" + "errors" + "fmt" "strconv" "strings" "time" + + "github.com/jinzhu/gorm" ) type HeartbeatReqTime time.Time type Heartbeat struct { - User string `json:"user"` - Entity string `json:"entity"` - Type string `json:"type"` - Category string `json:"category"` - Project string `json:"project"` - Branch string `json:"branch"` - Language string `json:"language"` - IsWrite bool `json:"is_write"` - Editor string `json:"editor"` - OperatingSystem string `json:"operating_system"` - Time HeartbeatReqTime `json:"time"` + gorm.Model + User *User `json:"user" gorm:"not_null; association_foreignkey:ID"` + UserID string `json:"-" gorm:"not_null"` + Entity string `json:"entity" gorm:"not_null"` + Type string `json:"type"` + Category string `json:"category"` + Project string `json:"project; index:idx_project"` + Branch string `json:"branch"` + Language string `json:"language" gorm:"not_null; index:idx_language"` + IsWrite bool `json:"is_write"` + Editor string `json:"editor" gorm:"not_null; index:idx_editor"` + OperatingSystem string `json:"operating_system" gorm:"not_null; index:idx_os"` + Time *HeartbeatReqTime `json:"time" gorm:"type:timestamp; default:now(); index:idx_time"` +} + +func (h *Heartbeat) Valid() bool { + return h.User != nil && h.UserID != "" && h.Entity != "" && h.Language != "" && h.Editor != "" && h.OperatingSystem != "" && h.Time != nil } func (j *HeartbeatReqTime) UnmarshalJSON(b []byte) error { @@ -33,6 +44,25 @@ func (j *HeartbeatReqTime) UnmarshalJSON(b []byte) error { return nil } +func (j *HeartbeatReqTime) Scan(value interface{}) error { + fmt.Printf("%T", value) + switch value.(type) { + case int64: + *j = HeartbeatReqTime(time.Unix(123456, 0)) + break + case time.Time: + *j = HeartbeatReqTime(value.(time.Time)) + break + default: + return errors.New(fmt.Sprintf("Unsupported type")) + } + return nil +} + +func (j HeartbeatReqTime) Value() (driver.Value, error) { + return time.Time(j), nil +} + func (j HeartbeatReqTime) String() string { t := time.Time(j) return t.Format("2006-01-02 15:04:05") diff --git a/models/user.go b/models/user.go index e4b28b3..c336b5f 100644 --- a/models/user.go +++ b/models/user.go @@ -1,6 +1,6 @@ package models type User struct { - UserId string `json:"id"` - ApiKey string `json:"api_key"` + ID string `json:"id" gorm:"primary_key"` + ApiKey string `json:"api_key" gorm:"unique"` } diff --git a/routes/heartbeat.go b/routes/heartbeat.go index 5ef13d2..d1acd9f 100644 --- a/routes/heartbeat.go +++ b/routes/heartbeat.go @@ -8,7 +8,6 @@ import ( "github.com/n1try/wakapi/services" "github.com/n1try/wakapi/utils" - _ "github.com/go-sql-driver/mysql" "github.com/n1try/wakapi/models" ) @@ -22,24 +21,32 @@ func (h *HeartbeatHandler) Post(w http.ResponseWriter, r *http.Request) { return } + var heartbeats []models.Heartbeat + user := r.Context().Value(models.UserKey).(*models.User) opSys, editor, _ := utils.ParseUserAgent(r.Header.Get("User-Agent")) dec := json.NewDecoder(r.Body) - var heartbeats []*models.Heartbeat - err := dec.Decode(&heartbeats) - if err != nil { + if err := dec.Decode(&heartbeats); err != nil { w.WriteHeader(400) w.Write([]byte(err.Error())) return } - for _, h := range heartbeats { + + for i := 0; i < len(heartbeats); i++ { + h := &heartbeats[i] h.OperatingSystem = opSys h.Editor = editor + h.User = user + h.UserID = user.ID + + if !h.Valid() { + w.WriteHeader(400) + w.Write([]byte("Invalid heartbeat object.")) + return + } } - user := r.Context().Value(models.UserKey).(*models.User) - err = h.HeartbeatSrvc.InsertBatch(heartbeats, user) - if err != nil { + if err := h.HeartbeatSrvc.InsertBatch(&heartbeats); err != nil { w.WriteHeader(500) os.Stderr.WriteString(err.Error()) return diff --git a/services/aggregation.go b/services/aggregation.go index 8eb2519..2c31700 100644 --- a/services/aggregation.go +++ b/services/aggregation.go @@ -1,16 +1,16 @@ package services import ( - "database/sql" "fmt" "log" "time" + "github.com/jinzhu/gorm" "github.com/n1try/wakapi/models" ) type AggregationService struct { - Db *sql.DB + Db *gorm.DB HeartbeatService *HeartbeatService } @@ -19,7 +19,7 @@ func (srv *AggregationService) Aggregate(from time.Time, to time.Time, user *mod if err != nil { log.Fatal(err) } - for _, h := range heartbeats { + for _, h := range *heartbeats { fmt.Printf("%+v\n", h) } } diff --git a/services/heartbeat.go b/services/heartbeat.go index b27f740..4bf6f56 100644 --- a/services/heartbeat.go +++ b/services/heartbeat.go @@ -1,78 +1,38 @@ package services import ( - "database/sql" - "errors" - "fmt" "time" + "github.com/jinzhu/gorm" "github.com/n1try/wakapi/models" + gormbulk "github.com/t-tiger/gorm-bulk-insert" ) const TableHeartbeat = "heartbeat" type HeartbeatService struct { - Db *sql.DB + Db *gorm.DB } -func (srv *HeartbeatService) InsertBatch(heartbeats []*models.Heartbeat, user *models.User) error { - qTpl := "INSERT INTO %+s (user, time, entity, type, category, is_write, project, branch, language, operating_system, editor) VALUES %+s;" - qFill := "" - vals := []interface{}{} - - for _, h := range heartbeats { - qFill = qFill + "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)," - vals = append(vals, user.UserId, h.Time.String(), h.Entity, h.Type, h.Category, h.IsWrite, h.Project, h.Branch, h.Language, h.OperatingSystem, h.Editor) +func (srv *HeartbeatService) InsertBatch(heartbeats *[]models.Heartbeat) error { + var batch []interface{} + for _, h := range *heartbeats { + batch = append(batch, h) } - q := fmt.Sprintf(qTpl, TableHeartbeat, qFill[0:len(qFill)-1]) - stmt, _ := srv.Db.Prepare(q) - result, err := stmt.Exec(vals...) - if err != nil { + if err := gormbulk.BulkInsert(srv.Db, batch, 3000); err != nil { return err } - n, err := result.RowsAffected() - if err != nil || n != int64(len(heartbeats)) { - return errors.New(fmt.Sprintf("Failed to insert %+v rows.", len(heartbeats))) - } return nil } -func (srv *HeartbeatService) GetAllFrom(date time.Time, user *models.User) ([]models.Heartbeat, error) { - q := fmt.Sprintf("SELECT user, time, language, project, operating_system, editor FROM %+s WHERE time >= ? AND user = ?", TableHeartbeat) - rows, err := srv.Db.Query(q, date.String(), user.UserId) - defer rows.Close() - if err != nil { - return make([]models.Heartbeat, 0), err - } - +func (srv *HeartbeatService) GetAllFrom(date time.Time, user *models.User) (*[]models.Heartbeat, error) { var heartbeats []models.Heartbeat - for rows.Next() { - var h models.Heartbeat - var language sql.NullString - var project sql.NullString - var operatingSystem sql.NullString - var editor sql.NullString - - err := rows.Scan(&h.User, &h.Time, &language, &project, &operatingSystem, &editor) - - if language.Valid { - h.Language = language.String - } - if project.Valid { - h.Project = project.String - } - if operatingSystem.Valid { - h.OperatingSystem = operatingSystem.String - } - if editor.Valid { - h.Editor = editor.String - } - - if err != nil { - return make([]models.Heartbeat, 0), err - } - heartbeats = append(heartbeats, h) + if err := srv.Db. + Where(&models.Heartbeat{UserID: user.ID}). + Where("time > ?", date). + Find(&heartbeats).Error; err != nil { + return nil, err } - return heartbeats, nil + return &heartbeats, nil } diff --git a/services/user.go b/services/user.go index 4c641d7..747bc26 100644 --- a/services/user.go +++ b/services/user.go @@ -1,33 +1,27 @@ package services import ( - "database/sql" - "fmt" - + "github.com/jinzhu/gorm" "github.com/n1try/wakapi/models" ) const TableUser = "user" type UserService struct { - Db *sql.DB + Db *gorm.DB } -func (srv *UserService) GetUserById(userId string) (models.User, error) { - q := fmt.Sprintf("SELECT user_id, api_key FROM %+s WHERE user_id = ?;", TableUser) - u := models.User{} - err := srv.Db.QueryRow(q, userId).Scan(&u.UserId, &u.ApiKey) - if err != nil { +func (srv *UserService) GetUserById(userId string) (*models.User, error) { + u := &models.User{} + if err := srv.Db.Where(&models.User{ID: userId}).First(u).Error; err != nil { return u, err } return u, nil } -func (srv *UserService) GetUserByKey(key string) (models.User, error) { - q := fmt.Sprintf("SELECT user_id, api_key FROM %+s WHERE api_key = ?;", TableUser) - var u models.User - err := srv.Db.QueryRow(q, key).Scan(&u.UserId, &u.ApiKey) - if err != nil { +func (srv *UserService) GetUserByKey(key string) (*models.User, error) { + u := &models.User{} + if err := srv.Db.Where(&models.User{ApiKey: key}).First(u).Error; err != nil { return u, err } return u, nil diff --git a/utils/common.go b/utils/common.go index 8d7c5ed..8ff6903 100644 --- a/utils/common.go +++ b/utils/common.go @@ -3,7 +3,10 @@ package utils import ( "errors" "regexp" + "strings" "time" + + "github.com/n1try/wakapi/models" ) func ParseDate(date string) (time.Time, error) { @@ -22,3 +25,16 @@ func ParseUserAgent(ua string) (string, string, error) { } return groups[0][1], groups[0][2], nil } + +func MakeConnectionString(config *models.Config) string { + str := strings.Builder{} + str.WriteString(config.DbUser) + str.WriteString(":") + str.WriteString(config.DbPassword) + str.WriteString("@tcp(") + str.WriteString(config.DbHost) + str.WriteString(")/") + str.WriteString(config.DbName) + str.WriteString("?charset=utf8&parseTime=true") + return str.String() +}