from atproto.models import ( Facet, FacetFeature, FacetIndex, LinkFeature, MentionFeature, TagFeature, ) from cross.tokens import LinkToken, MentionToken, TagToken, TextToken, Token from util.splitter import canonical_label def richtext_to_tokens(text: str, facets: list[dict]) -> list[Token]: if not text: return [] ut8_text = text.encode("utf-8") if not facets: return [TextToken(text=ut8_text.decode("utf-8"))] slices: list[tuple[int, int, str, str]] = [] for facet in facets: features: list[dict] = facet.get("features", []) if not features: continue feature = features[0] feature_type = feature["$type"] index = facet["index"] match feature_type: case "app.bsky.richtext.facet#tag": slices.append( (index["byteStart"], index["byteEnd"], "tag", feature["tag"]) ) case "app.bsky.richtext.facet#link": slices.append( (index["byteStart"], index["byteEnd"], "link", feature["uri"]) ) case "app.bsky.richtext.facet#mention": slices.append( (index["byteStart"], index["byteEnd"], "mention", feature["did"]) ) if not slices: return [TextToken(text=ut8_text.decode("utf-8"))] slices.sort(key=lambda s: s[0]) unique: list[tuple[int, int, str, str]] = [] current_end = 0 for start, end, ttype, val in slices: if start >= current_end: unique.append((start, end, ttype, val)) current_end = end if not unique: return [TextToken(text=ut8_text.decode("utf-8"))] tokens: list[Token] = [] prev = 0 for start, end, ttype, val in unique: if start > prev: tokens.append(TextToken(text=ut8_text[prev:start].decode("utf-8"))) match ttype: case "link": label = ut8_text[start:end].decode("utf-8") split = val.split("://", 1) if ( len(split) > 1 and split[1].startswith(label) or (label.endswith("...") and split[1].startswith(label[:-3])) ): tokens.append(LinkToken(href=val)) prev = end continue tokens.append(LinkToken(href=val, label=label)) case "tag": tag = ut8_text[start:end].decode("utf-8") tokens.append(TagToken(tag=tag[1:] if tag.startswith("#") else tag)) case "mention": mention = ut8_text[start:end].decode("utf-8") tokens.append( MentionToken( username=mention[1:] if mention.startswith("@") else mention, uri=val, ) ) prev = end if prev < len(ut8_text): tokens.append(TextToken(text=ut8_text[prev:].decode("utf-8"))) return tokens def tokens_to_richtext(tokens: list[Token]) -> tuple[str, list[Facet]] | None: segments: list[tuple[str, FacetFeature | None]] = [] byte_offset = 0 for token in tokens: match token: case TextToken(): text_bytes = token.text.encode("utf-8") segments.append((token.text, None)) byte_offset += len(text_bytes) case TagToken(): tag_text = f"#{token.tag}" tag_bytes = tag_text.encode("utf-8") segments.append( ( tag_text, TagFeature(tag=token.tag), ) ) byte_offset += len(tag_bytes) case MentionToken(): mention_text = f"@{token.username}" mention_bytes = mention_text.encode("utf-8") segments.append( ( mention_text, MentionFeature(did=token.uri) if token.uri else MentionFeature(did=""), ) ) byte_offset += len(mention_bytes) case LinkToken(): href = token.href label = token.label if token.label else href if canonical_label(token.label, token.href): max_label_len = 30 label_bytes = label.encode("utf-8") if len(label_bytes) > max_label_len: label = label[: max_label_len - 1] + "…" label_bytes = label.encode("utf-8") else: label_bytes = label.encode("utf-8") segments.append( ( label, LinkFeature(uri=href), ) ) byte_offset += len(label_bytes) case _: return None text = "".join(seg[0] for seg in segments) facets: list[Facet] = [] current_offset = 0 for seg_text, seg_feature in segments: if seg_feature: seg_bytes = seg_text.encode("utf-8") facets.append( Facet( index=FacetIndex( byte_start=current_offset, byte_end=current_offset + len(seg_bytes), ), features=[seg_feature], ) ) current_offset += len(seg_text.encode("utf-8")) return text, facets