A Rust implementation of skywatch-phash
1use image::DynamicImage;
2use image_hasher::{HashAlg, HasherConfig, ImageHash};
3use miette::Diagnostic;
4use thiserror::Error;
5
6#[derive(Debug, Error, Diagnostic)]
7pub enum PhashError {
8 #[error("Failed to decode image")]
9 ImageDecode(#[from] image::ImageError),
10
11 #[error("Invalid hash format: {0}")]
12 InvalidHashFormat(String),
13
14 #[error("Invalid hex string")]
15 ParseInt(#[from] std::num::ParseIntError),
16}
17
18/// Compute perceptual hash for an image using average hash (aHash) algorithm
19///
20/// This matches the TypeScript implementation:
21/// 1. Resize to 8x8 (64 pixels)
22/// 2. Convert to grayscale
23/// 3. Compute average pixel value
24/// 4. Create 64-bit binary: 1 if pixel > avg, 0 otherwise
25/// 5. Convert to hex string (16 chars)
26pub fn compute_phash(image_bytes: &[u8]) -> Result<String, PhashError> {
27 let img = image::load_from_memory(image_bytes)?;
28 compute_phash_from_image(&img)
29}
30
31/// Compute perceptual hash from a DynamicImage
32pub fn compute_phash_from_image(img: &DynamicImage) -> Result<String, PhashError> {
33 // Configure hasher with aHash (Mean) algorithm and 8x8 size
34 let hasher = HasherConfig::new()
35 .hash_alg(HashAlg::Mean) // average hash
36 .hash_size(8, 8) // 64 bits
37 .to_hasher();
38
39 // Compute hash
40 let hash = hasher.hash_image(img);
41
42 // Convert to hex string
43 hash_to_hex(&hash)
44}
45
46/// Convert ImageHash to hex string format (16 chars, matching TS output)
47fn hash_to_hex(hash: &ImageHash) -> Result<String, PhashError> {
48 // Get hash bytes
49 let bytes = hash.as_bytes();
50
51 // Convert to hex string
52 let hex = bytes
53 .iter()
54 .map(|b| format!("{:02x}", b))
55 .collect::<String>();
56
57 // Ensure it's 16 characters (64 bits = 8 bytes = 16 hex chars)
58 if hex.len() != 16 {
59 return Err(PhashError::InvalidHashFormat(format!(
60 "Expected 16 hex characters, got {}",
61 hex.len()
62 ))
63 .into());
64 }
65
66 Ok(hex)
67}
68
69/// Compute hamming distance between two phash hex strings
70///
71/// Uses Brian Kernighan's algorithm to count set bits
72pub fn hamming_distance(hash1: &str, hash2: &str) -> Result<u32, PhashError> {
73 // Validate input lengths
74 if hash1.len() != 16 || hash2.len() != 16 {
75 return Err(PhashError::InvalidHashFormat(format!(
76 "Hashes must be 16 hex characters, got {} and {}",
77 hash1.len(),
78 hash2.len()
79 ))
80 .into());
81 }
82
83 let a = u64::from_str_radix(hash1, 16)?;
84 let b = u64::from_str_radix(hash2, 16)?;
85
86 // XOR to find differing bits
87 let xor = a ^ b;
88
89 // Count set bits using Brian Kernighan's algorithm
90 let mut count = 0u32;
91 let mut n = xor;
92 while n > 0 {
93 count += 1;
94 n &= n - 1; // clear the lowest set bit
95 }
96
97 Ok(count)
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// A hash compared with itself differs in zero bits.
    #[test]
    fn test_hamming_distance_identical() {
        let sample = "e0e0e0e0e0fcfefe";
        assert_eq!(hamming_distance(sample, sample).unwrap(), 0);
    }

    /// All-zero versus all-one hashes differ in every one of the 64 bits.
    #[test]
    fn test_hamming_distance_different() {
        let zeros = "0000000000000000";
        let ones = "ffffffffffffffff";
        assert_eq!(hamming_distance(zeros, ones).unwrap(), 64);
    }

    /// Flipping only the lowest bit yields a distance of one.
    #[test]
    fn test_hamming_distance_one_bit() {
        let base = "0000000000000000";
        let flipped = "0000000000000001";
        assert_eq!(hamming_distance(base, flipped).unwrap(), 1);
    }

    /// A hash that is not 16 characters long is rejected.
    #[test]
    fn test_hamming_distance_invalid_length() {
        let outcome = hamming_distance("e0e0e0e0e0fcfefe", "short");
        assert!(outcome.is_err());
    }

    /// A 16-character string with non-hex characters is rejected.
    #[test]
    fn test_hamming_distance_invalid_hex() {
        let outcome = hamming_distance("e0e0e0e0e0fcfefe", "gggggggggggggggg");
        assert!(outcome.is_err());
    }

    /// Hashing a tiny image produces a 16-character string that parses as hex.
    #[test]
    fn test_phash_format() {
        // Smallest possible input: a 1x1 luma image from `DynamicImage::new_luma8`.
        let tiny = DynamicImage::new_luma8(1, 1);
        let hex = compute_phash_from_image(&tiny).unwrap();

        assert_eq!(hex.len(), 16);
        assert!(u64::from_str_radix(&hex, 16).is_ok());
    }
}
155}