A Go implementation of Facebook's PDQ
trust-and-safety pdq
at main 344 lines 8.6 kB view raw
1// Reimplementation of https://github.com/facebook/ThreatExchange/blob/main/pdq in Golang 2// 3// For reference, please see https://github.com/facebook/ThreatExchange/blob/main/hashing/hashing.pdf 4// 5// Function names are similar or the same as those in the reference C++ implementation, and 6// any questions about implementation should reference that code. 7 8package pdq 9 10import ( 11 "errors" 12 "fmt" 13 "image" 14 _ "image/gif" 15 _ "image/jpeg" 16 _ "image/png" 17 "math" 18 "os" 19 "time" 20 21 _ "golang.org/x/image/bmp" 22 _ "golang.org/x/image/tiff" 23 _ "golang.org/x/image/webp" 24) 25 26// HashResult contains the output of a PDQ hash operation 27type HashResult struct { 28 Hash string 29 Quality int 30 ImageHeightTimesWidth int 31 HashDuration time.Duration 32} 33 34// Various constants pulled from the reference implementation 35const ( 36 LumaFromRCoeff = 0.299 37 LumaFromGCoeff = 0.587 38 LumaFromBCoeff = 0.114 39 40 PdqNumJaroszXYPasses = 2 41 42 DownsampleDims = 512 43 44 MinHashableDim = 5 45) 46 47var ( 48 ErrInvalidFile = errors.New("invalid input file name") 49) 50 51// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/pdqhashing.cpp#L42 52var dctMatrix64 []float32 53 54func init() { 55 const numRows = 16 56 const numCols = 64 57 58 matrixScaleFactor := math.Sqrt(2.0 / float64(numCols)) 59 60 dctMatrix64 = make([]float32, numRows*numCols) 61 62 for i := range numRows { 63 for j := range numCols { 64 dctMatrix64[i*numCols+j] = float32(matrixScaleFactor * math.Cos((math.Pi/2.0/float64(numCols))*float64(i+1)*float64(2*j+1))) 65 } 66 } 67} 68 69// HashFromImage generates a PDQ hash from an image.Image 70// The image should idealy be pre-resizes to 512x512 or smaller for performance reasons. 71// SEE: https://github.com/facebook/ThreatExchange/blob/main/hashing/hashing.pdf, "More on Downsampling" 72// Returns a HashResult containing the hash and a quality score between 0 and 100. 73// Please reference the evaluation data for selecting a good quality score. From hashing.pdf: 74// "Confident-match distances are up to the system designer, of course, but 30, 20, or less has been found to 75// produce good results on evaluation data." 76func HashFromImage(img image.Image) (*HashResult, error) { 77 bounds := img.Bounds() 78 size := bounds.Size() 79 80 imageHeightTimesWidth := size.Y * size.X 81 82 luma, numRows, numCols := loadFloatLumaFromImage(img) 83 84 fullBuffer2 := make([]float32, numRows*numCols) 85 86 hashStart := time.Now() 87 hash, quality := hash256FromFloatLuma(luma, fullBuffer2, numRows, numCols) 88 hashTime := time.Since(hashStart) 89 90 return &HashResult{ 91 Hash: hash, 92 Quality: quality, 93 ImageHeightTimesWidth: imageHeightTimesWidth, 94 HashDuration: hashTime, 95 }, nil 96} 97 98// Opens a file at the specified file and uses image.Image to decode the image. Returns the result of 99// HashFromImage. This is a convenience wrapper around HashFromImage that handles the IO and decoding for you. 100// Ideally, you should call HashFromImage on your own with a 512x512 or smaller image that you have resized 101// yourself. This function is provided only to match the reference implementation. 102func HashFromFile(filename string) (*HashResult, error) { 103 if filename == "" { 104 return nil, ErrInvalidFile 105 } 106 107 file, err := os.Open(filename) 108 if err != nil { 109 return nil, fmt.Errorf("failed to open file for hashing: %w", err) 110 } 111 defer file.Close() 112 113 img, _, err := image.Decode(file) 114 if err != nil { 115 return nil, fmt.Errorf("failed to decode image: %w", err) 116 } 117 118 return HashFromImage(img) 119} 120 121func loadFloatLumaFromImage(img image.Image) ([]float32, int, int) { 122 bounds := img.Bounds() 123 numRows := bounds.Dy() 124 numCols := bounds.Dx() 125 luma := make([]float32, numRows*numCols) 126 127 for row := range numRows { 128 for col := range numCols { 129 // purposefully discarding alpha 130 r, g, b, _ := img.At(bounds.Min.X+col, bounds.Min.Y+row).RGBA() 131 132 r8 := float32(r >> 8) 133 g8 := float32(g >> 8) 134 b8 := float32(b >> 8) 135 136 luma[row*numCols+col] = LumaFromRCoeff*r8 + LumaFromGCoeff*g8 + LumaFromBCoeff*b8 137 } 138 } 139 140 return luma, numRows, numCols 141} 142 143// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/pdqhashing.cpp#L127 144func hash256FromFloatLuma( 145 fullBuffer1 []float32, 146 fullBuffer2 []float32, 147 numRows, numCols int, 148) (string, int) { 149 // from reference impl, do not return a hash for images taht are too small 150 if numRows < MinHashableDim || numCols < MinHashableDim { 151 return "", 0 152 } 153 154 buffer64x64 := make([]float32, 64*64) 155 buffer16x64 := make([]float32, 16*64) 156 buffer16x16 := make([]float32, 16*16) 157 158 quality := float256FromFloatLuma(fullBuffer1, fullBuffer2, numRows, numCols, buffer64x64, buffer16x64, buffer16x16) 159 160 hash := convertBufferToHash(buffer16x16) 161 162 return hash, quality 163} 164 165const hexChars = "0123456789abcdef" 166 167func convertBufferToHash(buffer16x16 []float32) string { 168 median := torben(buffer16x16) 169 170 words := make([]uint16, 16) 171 172 for i := range 16 { 173 for j := range 16 { 174 if buffer16x16[i*16+j] > median { 175 bitIndex := i*16 + j 176 wordIndex := bitIndex / 16 177 bitInWord := bitIndex % 16 178 words[wordIndex] |= 1 << bitInWord 179 } 180 } 181 } 182 183 result := make([]byte, 64) 184 for i := range 16 { 185 word := words[15-i] 186 offset := i * 4 187 result[offset+0] = hexChars[word>>12] 188 result[offset+1] = hexChars[(word>>8)&0xF] 189 result[offset+2] = hexChars[(word>>4)&0xF] 190 result[offset+3] = hexChars[word&0xF] 191 } 192 193 return string(result) 194} 195 196// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/pdqhashing.cpp#L158 197func float256FromFloatLuma( 198 fullBuffer1 []float32, 199 fullBuffer2 []float32, 200 numRows, numCols int, 201 buffer64x64 []float32, 202 buffer16x64 []float32, 203 buffer16x16 []float32, 204) int { 205 if numRows == 64 && numCols == 64 { 206 copy(buffer64x64, fullBuffer1) 207 } else { 208 windowSizeAlongRows := computeJaroszFilterWindowSize(numCols, 64) 209 windowSizeAlongCols := computeJaroszFilterWindowSize(numRows, 64) 210 211 jaroszFilterFloat(fullBuffer1, fullBuffer2, numRows, numCols, windowSizeAlongRows, windowSizeAlongCols, PdqNumJaroszXYPasses) 212 213 decimateFloat(fullBuffer1, numRows, numCols, buffer64x64, 64, 64) 214 } 215 216 quality := imageDomainQualityMetric(buffer64x64) 217 218 dct64To16(buffer64x64, buffer16x64, buffer16x16) 219 220 return quality 221} 222 223// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/pdqhashing.cpp#L318 224func imageDomainQualityMetric(buffer64x64 []float32) int { 225 gradientSum := 0 226 227 for i := range 63 { 228 for j := range 64 { 229 u := buffer64x64[i*64+j] 230 v := buffer64x64[(i+1)*64+j] 231 d := int(math.Abs(float64((u - v) * 100 / 255))) 232 gradientSum += d 233 } 234 } 235 236 for i := range 64 { 237 for j := range 63 { 238 u := buffer64x64[i*64+j] 239 v := buffer64x64[i*64+j+1] 240 d := int(math.Abs(float64((u - v) * 100 / 255))) 241 gradientSum += d 242 } 243 } 244 245 quality := min(gradientSum/90, 100) 246 247 return quality 248} 249 250// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/pdqhashing.cpp#L355 251func dct64To16(A []float32, T []float32, B []float32) { 252 for i := range 16 { 253 dctRow := dctMatrix64[i*64:] 254 for j := range 64 { 255 var sum0, sum1, sum2, sum3 float32 256 257 for k := 0; k < 64; k += 4 { 258 sum0 += dctRow[k] * A[k*64+j] 259 sum1 += dctRow[k+1] * A[(k+1)*64+j] 260 sum2 += dctRow[k+2] * A[(k+2)*64+j] 261 sum3 += dctRow[k+3] * A[(k+3)*64+j] 262 } 263 264 T[i*64+j] = sum0 + sum1 + sum2 + sum3 265 } 266 } 267 268 for i := range 16 { 269 tRow := T[i*64:] 270 for j := range 16 { 271 dctRow := dctMatrix64[j*64:] 272 var sum0, sum1, sum2, sum3 float32 273 274 for k := 0; k < 64; k += 4 { 275 sum0 += tRow[k] * dctRow[k] 276 sum1 += tRow[k+1] * dctRow[k+1] 277 sum2 += tRow[k+2] * dctRow[k+2] 278 sum3 += tRow[k+3] * dctRow[k+3] 279 } 280 281 B[i*16+j] = sum0 + sum1 + sum2 + sum3 282 } 283 } 284} 285 286// SEE: https://github.com/facebook/ThreatExchange/blob/main/pdq/cpp/hashing/torben.cpp 287func torben(m []float32) float32 { 288 n := len(m) 289 if n == 0 { 290 return 0 291 } 292 293 min, max := m[0], m[0] 294 for i := 1; i < n; i++ { 295 if m[i] < min { 296 min = m[i] 297 } 298 if m[i] > max { 299 max = m[i] 300 } 301 } 302 303 var guess, maxltguess, mingtguess float32 304 var less, greater, equal int 305 306 for { 307 guess = (min + max) / 2 308 less, greater, equal = 0, 0, 0 309 maxltguess = min 310 mingtguess = max 311 312 for i := range n { 313 if m[i] < guess { 314 less++ 315 if m[i] > maxltguess { 316 maxltguess = m[i] 317 } 318 } else if m[i] > guess { 319 greater++ 320 if m[i] < mingtguess { 321 mingtguess = m[i] 322 } 323 } else { 324 equal++ 325 } 326 } 327 328 if less <= (n+1)/2 && greater <= (n+1)/2 { 329 break 330 } else if less > greater { 331 max = maxltguess 332 } else { 333 min = mingtguess 334 } 335 } 336 337 if less >= (n+1)/2 { 338 return maxltguess 339 } else if less+equal >= (n+1)/2 { 340 return guess 341 } else { 342 return mingtguess 343 } 344}