+148
cmd/repoguard/main.go
+148
cmd/repoguard/main.go
···
1
+
package main
2
+
3
+
import (
4
+
"flag"
5
+
"fmt"
6
+
"log"
7
+
"os"
8
+
"os/exec"
9
+
"path/filepath"
10
+
"strings"
11
+
"time"
12
+
)
13
+
14
+
var (
15
+
logger *log.Logger
16
+
logFile *os.File
17
+
clientIP string
18
+
19
+
// Command line flags
20
+
allowedUser = flag.String("user", "", "Allowed git user")
21
+
baseDirFlag = flag.String("base-dir", "/home/git", "Base directory for git repositories")
22
+
logPathFlag = flag.String("log-path", "/var/log/git-wrapper.log", "Path to log file")
23
+
)
24
+
25
+
func main() {
26
+
flag.Parse()
27
+
28
+
defer cleanup()
29
+
initLogger()
30
+
31
+
// Get client IP from SSH environment
32
+
if connInfo := os.Getenv("SSH_CONNECTION"); connInfo != "" {
33
+
parts := strings.Fields(connInfo)
34
+
if len(parts) > 0 {
35
+
clientIP = parts[0]
36
+
}
37
+
}
38
+
39
+
if *allowedUser == "" {
40
+
exitWithLog("access denied: no user specified")
41
+
}
42
+
43
+
sshCommand := os.Getenv("SSH_ORIGINAL_COMMAND")
44
+
45
+
logEvent("Connection attempt", map[string]interface{}{
46
+
"user": *allowedUser,
47
+
"command": sshCommand,
48
+
"client": clientIP,
49
+
})
50
+
51
+
if sshCommand == "" {
52
+
exitWithLog("access denied: no ssh command provided")
53
+
}
54
+
55
+
cmdParts := strings.Fields(sshCommand)
56
+
if len(cmdParts) < 2 {
57
+
exitWithLog("invalid command format")
58
+
}
59
+
60
+
gitCommand := cmdParts[0]
61
+
repoName := strings.Trim(cmdParts[1], "'")
62
+
63
+
validCommands := map[string]bool{
64
+
"git-receive-pack": true,
65
+
"git-upload-pack": true,
66
+
"git-upload-archive": true,
67
+
}
68
+
if !validCommands[gitCommand] {
69
+
exitWithLog("access denied: invalid git command")
70
+
}
71
+
72
+
if !isAllowedUser(*allowedUser, repoName) {
73
+
exitWithLog("access denied: user not allowed")
74
+
}
75
+
76
+
fullPath := filepath.Join(*baseDirFlag, repoName)
77
+
fullPath = filepath.Clean(fullPath)
78
+
79
+
logEvent("Processing command", map[string]interface{}{
80
+
"user": *allowedUser,
81
+
"command": gitCommand,
82
+
"repo": repoName,
83
+
"fullPath": fullPath,
84
+
"client": clientIP,
85
+
})
86
+
87
+
cmd := exec.Command(gitCommand, fullPath)
88
+
cmd.Stdout = os.Stdout
89
+
cmd.Stderr = os.Stderr
90
+
cmd.Stdin = os.Stdin
91
+
92
+
if err := cmd.Run(); err != nil {
93
+
exitWithLog(fmt.Sprintf("command failed: %v", err))
94
+
}
95
+
96
+
logEvent("Command completed", map[string]interface{}{
97
+
"user": *allowedUser,
98
+
"command": gitCommand,
99
+
"repo": repoName,
100
+
"success": true,
101
+
})
102
+
}
103
+
104
+
func initLogger() {
105
+
var err error
106
+
logFile, err = os.OpenFile(*logPathFlag, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
107
+
if err != nil {
108
+
fmt.Fprintf(os.Stderr, "failed to open log file: %v\n", err)
109
+
os.Exit(1)
110
+
}
111
+
112
+
logger = log.New(logFile, "", 0)
113
+
}
114
+
115
+
func logEvent(event string, fields map[string]interface{}) {
116
+
entry := fmt.Sprintf(
117
+
"timestamp=%q event=%q",
118
+
time.Now().Format(time.RFC3339),
119
+
event,
120
+
)
121
+
122
+
for k, v := range fields {
123
+
entry += fmt.Sprintf(" %s=%q", k, v)
124
+
}
125
+
126
+
logger.Println(entry)
127
+
}
128
+
129
+
func exitWithLog(message string) {
130
+
logEvent("Access denied", map[string]interface{}{
131
+
"error": message,
132
+
})
133
+
logFile.Sync()
134
+
fmt.Fprintf(os.Stderr, "error: %s\n", message)
135
+
os.Exit(1)
136
+
}
137
+
138
+
func cleanup() {
139
+
if logFile != nil {
140
+
logFile.Sync()
141
+
logFile.Close()
142
+
}
143
+
}
144
+
145
+
func isAllowedUser(user, repoPath string) bool {
146
+
pathUser := strings.Split(repoPath, "/")[0]
147
+
return pathUser == user
148
+
}