A Go implementation of Facebook's PDQ
trust-and-safety
pdq
1// Reimplementation of https://github.com/facebook/ThreatExchange/blob/main/pdq in Golang
2//
3// For reference, please see https://github.com/facebook/ThreatExchange/blob/main/hashing/hashing.pdf
4//
5// Function names are similar or the same as those in the reference C++ implementation, and
6// any questions about implementation should reference that code.
7
8package pdq
9
10import (
11 "errors"
12 "fmt"
13 "image"
14 _ "image/gif"
15 _ "image/jpeg"
16 _ "image/png"
17 "math"
18 "os"
19 "time"
20
21 _ "golang.org/x/image/bmp"
22 _ "golang.org/x/image/tiff"
23 _ "golang.org/x/image/webp"
24)
25
// HashResult contains the output of a PDQ hash operation.
type HashResult struct {
	// Hash is the hash as produced by hash256FromFloatLuma: a 64-character
	// lowercase hex string, or "" when either image dimension is below
	// MinHashableDim.
	Hash string
	// Quality is the score from imageDomainQualityMetric, capped at 100.
	// It is 0 when the image was too small to hash.
	Quality int
	// ImageHeightTimesWidth is the height multiplied by the width of the
	// image bounds, in pixels.
	ImageHeightTimesWidth int
	// HashDuration is the wall-clock time spent in hash256FromFloatLuma.
	// It does not include converting the image to luma.
	HashDuration time.Duration
}
33
// Various constants pulled from the reference implementation
const (
	// LumaFromRCoeff, LumaFromGCoeff and LumaFromBCoeff are the weights that
	// loadFloatLumaFromImage applies to the 8-bit red, green and blue channel
	// values when it computes the luma of a pixel.
	LumaFromRCoeff = 0.299
	LumaFromGCoeff = 0.587
	LumaFromBCoeff = 0.114

	// PdqNumJaroszXYPasses is the pass count handed to jaroszFilterFloat.
	PdqNumJaroszXYPasses = 2

	// DownsampleDims is not referenced in this part of the file; presumably
	// it is the recommended maximum edge length for input images
	// (TODO confirm against the reference implementation).
	DownsampleDims = 512

	// MinHashableDim is the smallest row or column count for which
	// hash256FromFloatLuma will produce a hash.
	MinHashableDim = 5
)
46
var (
	// ErrInvalidFile is returned by HashFromFile when the file name is empty.
	ErrInvalidFile = errors.New("invalid input file name")
)
50
// dctMatrix64 is a 16x64 table of DCT coefficients stored row-major in a flat
// slice. It is filled once by init and only read afterwards.
// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/pdqhashing.cpp#L42
var dctMatrix64 []float32

// init precomputes dctMatrix64. Entry (r, c) is
// sqrt(2/64) * cos(pi/128 * (r+1) * (2c+1)), evaluated in float64 and then
// narrowed to float32.
func init() {
	const (
		numRows = 16
		numCols = 64
	)

	scale := math.Sqrt(2.0 / float64(numCols))
	// Angular step shared by every entry of the table.
	step := math.Pi / 2.0 / float64(numCols)

	table := make([]float32, 0, numRows*numCols)
	for freq := 1; freq <= numRows; freq++ {
		for c := 0; c < numCols; c++ {
			angle := step * float64(freq) * float64(2*c+1)
			table = append(table, float32(scale*math.Cos(angle)))
		}
	}

	dctMatrix64 = table
}
68
69// HashFromImage generates a PDQ hash from an image.Image
70// The image should idealy be pre-resizes to 512x512 or smaller for performance reasons.
71// SEE: https://github.com/facebook/ThreatExchange/blob/main/hashing/hashing.pdf, "More on Downsampling"
72// Returns a HashResult containing the hash and a quality score between 0 and 100.
73// Please reference the evaluation data for selecting a good quality score. From hashing.pdf:
74// "Confident-match distances are up to the system designer, of course, but 30, 20, or less has been found to
75// produce good results on evaluation data."
76func HashFromImage(img image.Image) (*HashResult, error) {
77 bounds := img.Bounds()
78 size := bounds.Size()
79
80 imageHeightTimesWidth := size.Y * size.X
81
82 luma, numRows, numCols := loadFloatLumaFromImage(img)
83
84 fullBuffer2 := make([]float32, numRows*numCols)
85
86 hashStart := time.Now()
87 hash, quality := hash256FromFloatLuma(luma, fullBuffer2, numRows, numCols)
88 hashTime := time.Since(hashStart)
89
90 return &HashResult{
91 Hash: hash,
92 Quality: quality,
93 ImageHeightTimesWidth: imageHeightTimesWidth,
94 HashDuration: hashTime,
95 }, nil
96}
97
98// Opens a file at the specified file and uses image.Image to decode the image. Returns the result of
99// HashFromImage. This is a convenience wrapper around HashFromImage that handles the IO and decoding for you.
100// Ideally, you should call HashFromImage on your own with a 512x512 or smaller image that you have resized
101// yourself. This function is provided only to match the reference implementation.
102func HashFromFile(filename string) (*HashResult, error) {
103 if filename == "" {
104 return nil, ErrInvalidFile
105 }
106
107 file, err := os.Open(filename)
108 if err != nil {
109 return nil, fmt.Errorf("failed to open file for hashing: %w", err)
110 }
111 defer file.Close()
112
113 img, _, err := image.Decode(file)
114 if err != nil {
115 return nil, fmt.Errorf("failed to decode image: %w", err)
116 }
117
118 return HashFromImage(img)
119}
120
121func loadFloatLumaFromImage(img image.Image) ([]float32, int, int) {
122 bounds := img.Bounds()
123 numRows := bounds.Dy()
124 numCols := bounds.Dx()
125 luma := make([]float32, numRows*numCols)
126
127 for row := range numRows {
128 for col := range numCols {
129 // purposefully discarding alpha
130 r, g, b, _ := img.At(bounds.Min.X+col, bounds.Min.Y+row).RGBA()
131
132 r8 := float32(r >> 8)
133 g8 := float32(g >> 8)
134 b8 := float32(b >> 8)
135
136 luma[row*numCols+col] = LumaFromRCoeff*r8 + LumaFromGCoeff*g8 + LumaFromBCoeff*b8
137 }
138 }
139
140 return luma, numRows, numCols
141}
142
143// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/pdqhashing.cpp#L127
144func hash256FromFloatLuma(
145 fullBuffer1 []float32,
146 fullBuffer2 []float32,
147 numRows, numCols int,
148) (string, int) {
149 // from reference impl, do not return a hash for images taht are too small
150 if numRows < MinHashableDim || numCols < MinHashableDim {
151 return "", 0
152 }
153
154 buffer64x64 := make([]float32, 64*64)
155 buffer16x64 := make([]float32, 16*64)
156 buffer16x16 := make([]float32, 16*16)
157
158 quality := float256FromFloatLuma(fullBuffer1, fullBuffer2, numRows, numCols, buffer64x64, buffer16x64, buffer16x16)
159
160 hash := convertBufferToHash(buffer16x16)
161
162 return hash, quality
163}
164
165const hexChars = "0123456789abcdef"
166
167func convertBufferToHash(buffer16x16 []float32) string {
168 median := torben(buffer16x16)
169
170 words := make([]uint16, 16)
171
172 for i := range 16 {
173 for j := range 16 {
174 if buffer16x16[i*16+j] > median {
175 bitIndex := i*16 + j
176 wordIndex := bitIndex / 16
177 bitInWord := bitIndex % 16
178 words[wordIndex] |= 1 << bitInWord
179 }
180 }
181 }
182
183 result := make([]byte, 64)
184 for i := range 16 {
185 word := words[15-i]
186 offset := i * 4
187 result[offset+0] = hexChars[word>>12]
188 result[offset+1] = hexChars[(word>>8)&0xF]
189 result[offset+2] = hexChars[(word>>4)&0xF]
190 result[offset+3] = hexChars[word&0xF]
191 }
192
193 return string(result)
194}
195
196// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/pdqhashing.cpp#L158
197func float256FromFloatLuma(
198 fullBuffer1 []float32,
199 fullBuffer2 []float32,
200 numRows, numCols int,
201 buffer64x64 []float32,
202 buffer16x64 []float32,
203 buffer16x16 []float32,
204) int {
205 if numRows == 64 && numCols == 64 {
206 copy(buffer64x64, fullBuffer1)
207 } else {
208 windowSizeAlongRows := computeJaroszFilterWindowSize(numCols, 64)
209 windowSizeAlongCols := computeJaroszFilterWindowSize(numRows, 64)
210
211 jaroszFilterFloat(fullBuffer1, fullBuffer2, numRows, numCols, windowSizeAlongRows, windowSizeAlongCols, PdqNumJaroszXYPasses)
212
213 decimateFloat(fullBuffer1, numRows, numCols, buffer64x64, 64, 64)
214 }
215
216 quality := imageDomainQualityMetric(buffer64x64)
217
218 dct64To16(buffer64x64, buffer16x64, buffer16x16)
219
220 return quality
221}
222
223// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/pdqhashing.cpp#L318
224func imageDomainQualityMetric(buffer64x64 []float32) int {
225 gradientSum := 0
226
227 for i := range 63 {
228 for j := range 64 {
229 u := buffer64x64[i*64+j]
230 v := buffer64x64[(i+1)*64+j]
231 d := int(math.Abs(float64((u - v) * 100 / 255)))
232 gradientSum += d
233 }
234 }
235
236 for i := range 64 {
237 for j := range 63 {
238 u := buffer64x64[i*64+j]
239 v := buffer64x64[i*64+j+1]
240 d := int(math.Abs(float64((u - v) * 100 / 255)))
241 gradientSum += d
242 }
243 }
244
245 quality := min(gradientSum/90, 100)
246
247 return quality
248}
249
// dct64To16 computes B = D * A * transpose(D), where D is the 16x64 matrix in
// dctMatrix64, A is a 64x64 input, T is a 16x64 intermediate and B is the
// 16x16 output. All matrices are row-major flat slices.
//
// Each dot product is split across four partial sums that are added at the
// end. That fixes the order of the float32 additions, so reordering it can
// change the rounding of the result.
// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/pdqhashing.cpp#L355
func dct64To16(A []float32, T []float32, B []float32) {
	// First pass: T = D * A, i.e. T[i][j] = sum over k of D[i][k] * A[k][j].
	for i := range 16 {
		dctRow := dctMatrix64[i*64:]
		for j := range 64 {
			var sum0, sum1, sum2, sum3 float32

			// Walk down column j of A, four rows per iteration.
			for k := 0; k < 64; k += 4 {
				sum0 += dctRow[k] * A[k*64+j]
				sum1 += dctRow[k+1] * A[(k+1)*64+j]
				sum2 += dctRow[k+2] * A[(k+2)*64+j]
				sum3 += dctRow[k+3] * A[(k+3)*64+j]
			}

			T[i*64+j] = sum0 + sum1 + sum2 + sum3
		}
	}

	// Second pass: B = T * transpose(D), i.e. B[i][j] = sum over k of T[i][k] * D[j][k].
	for i := range 16 {
		tRow := T[i*64:]
		for j := range 16 {
			dctRow := dctMatrix64[j*64:]
			var sum0, sum1, sum2, sum3 float32

			for k := 0; k < 64; k += 4 {
				sum0 += tRow[k] * dctRow[k]
				sum1 += tRow[k+1] * dctRow[k+1]
				sum2 += tRow[k+2] * dctRow[k+2]
				sum3 += tRow[k+3] * dctRow[k+3]
			}

			B[i*16+j] = sum0 + sum1 + sum2 + sum3
		}
	}
}
285
// torben returns the median of m without modifying or copying it. It
// repeatedly bisects the value range [lo, hi] and counts how many elements
// fall below, above and on the midpoint, until neither side holds more than
// half of the elements. For an even number of elements the lower of the two
// middle values is returned. An empty slice yields 0.
// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/torben.cpp
func torben(m []float32) float32 {
	n := len(m)
	if n == 0 {
		return 0
	}

	lo, hi := m[0], m[0]
	for _, v := range m[1:] {
		if v < lo {
			lo = v
		}
		if v > hi {
			hi = v
		}
	}

	half := (n + 1) / 2

	for {
		pivot := (lo + hi) / 2
		below, above, same := 0, 0, 0
		// Closest elements on either side of the pivot seen in this pass.
		largestBelow, smallestAbove := lo, hi

		for _, v := range m {
			switch {
			case v < pivot:
				below++
				if v > largestBelow {
					largestBelow = v
				}
			case v > pivot:
				above++
				if v < smallestAbove {
					smallestAbove = v
				}
			default:
				same++
			}
		}

		switch {
		case below <= half && above <= half:
			// Balanced: the median is the pivot or one of its nearest neighbours.
			switch {
			case below >= half:
				return largestBelow
			case below+same >= half:
				return pivot
			default:
				return smallestAbove
			}
		case below > above:
			hi = largestBelow
		default:
			lo = smallestAbove
		}
	}
}