"""
Manages the FAISS vector index and essential metadata for MemVid.
- Stores essential metadata for fast search.
- Maintains bidirectional mapping between document IDs and video frame numbers for efficient deletion.
This module provides functionality for managing vector indices used in LangChain MemVid,
including FAISS index creation, updating, and searching.
"""
import faiss
import numpy as np
import hashlib
from pathlib import Path
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
import orjson
from .exceptions import MemVidIndexError
from .config import IndexConfig, LANGCHAIN_MEMVID_DEFAULT_INDEX_DIR
from .utils import ProgressDisplay
from .logging import get_logger
from .types import FrameMappingStats
logger = get_logger("index")
@dataclass
class SearchResult:
"""Represents a search result with metadata and similarity score."""
text: str
source: Optional[str] = None
category: Optional[str] = None
similarity: float = 0.0
metadata: Optional[Dict[str, Any]] = None
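
    @classmethod
    def from_metadata(cls, metadata: Dict[str, Any], similarity: float) -> "SearchResult":
        """Build a SearchResult from an essential-metadata dict.

        Note: this constructor is inferred from its use in
        ``IndexManager.search_text``; the original implementation may differ.
        """
        return cls(
            text=metadata.get("text", ""),
            source=metadata.get("source"),
            category=metadata.get("category"),
            similarity=similarity,
            metadata=metadata,
        )
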
class IndexManager:
"""Manages vector indices for MemVid.
This index manager implements a hybrid storage approach that optimizes storage efficiency
while maintaining performance and data integrity.
Hybrid Storage Approach
- Essential Metadata Only: Stores only essential metadata in FAISS for efficiency
- Document text, source, category, doc_id, metadata_hash
- Significant reduction in FAISS index size compared to full metadata storage
- Fast search operations with minimal memory usage
- Full Metadata in Video: Complete metadata stored in video QR codes
- All metadata fields and custom attributes
- Complete backup and archive functionality
- On-demand retrieval when needed
Optimization Strategies for Document Deletion
The index manager implements optimized deletion strategies to avoid full video rebuilds:
Frame Index Mapping
- Maintains bidirectional mapping between document IDs and frame numbers
- Enables O(1) lookup for frame numbers given document IDs
- Allows precise frame-level deletion without full video rebuilds
Performance Characteristics
- Search Performance: Sub-second search with essential metadata
- Storage Efficiency: Significant reduction in FAISS index size
- Deletion Performance: O(k) time complexity where k = frames to delete
- Memory Usage: Optimized for large-scale operations
Best Practices
- Batch Operations: Add or delete multiple documents at once for better efficiency
- Frame Mapping: Monitor frame mapping integrity for optimal deletion performance
- Metadata Management: Use essential metadata for search, full metadata for details
- Error Handling: Implement fallback mechanisms for corrupted data
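
    Example:
        A minimal usage sketch. ``FakeEmbeddings`` and the bare
        ``IndexConfig()`` defaults below are illustrative assumptions, not
        the only supported setup:

            from langchain_community.embeddings import FakeEmbeddings

            embeddings = FakeEmbeddings(size=128)  # any LangChain embeddings work
            manager = IndexManager(config=IndexConfig(), embeddings=embeddings)
            manager.create_index()
            manager.add_texts(["hello world"], metadata=[{"source": "demo"}])
            results = manager.search_text("hello world", k=1)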
"""
def __init__(
self,
config: IndexConfig,
embeddings: Any,
):
"""Initialize the index manager.
Args:
config: Configuration for the index
embeddings: LangChain embeddings interface
"""
self.config = config
self.embeddings = embeddings
self._index: Optional[faiss.Index] = None
self._metadata: List[Dict[str, Any]] = []
self._is_trained: bool = False
self._dimension: Optional[int] = None
self._min_points: Optional[int] = None
self._progress = ProgressDisplay(show_progress=config.show_progress)
# Frame mapping for efficient deletion
self._frame_mapping: Dict[int, int] = {} # doc_id -> frame_number
self._reverse_frame_mapping: Dict[int, int] = {} # frame_number -> doc_id
def create_index(self):
"""Create a new FAISS index based on the embeddings model."""
try:
# Get dimension from embeddings
test_vector = self.embeddings.embed_query("test")
self._dimension = len(test_vector)
match self.config.index_type:
case "faiss":
match self.config.metric:
case "cosine" | "ip":
self._index = faiss.IndexFlatIP(self._dimension)
case "l2":
self._index = faiss.IndexFlatL2(self._dimension)
case _:
raise MemVidIndexError(f"Unsupported metric: {self.config.metric}")
                    # If an IVF index was requested (nlist > 0)
if self.config.nlist > 0:
                        # Heuristic: FAISS recommends tens of training points
                        # per centroid, so require 30 * nlist before training
self._min_points = 30 * self.config.nlist
# Use flat index if minimum points is too high
if self._min_points > 1000:
logger.warning(
f"Minimum points required ({self._min_points}) is too high. "
"Falling back to flat index."
)
self._is_trained = True
else:
# We'll convert to IVF when we have enough points
self._is_trained = True
else:
self._is_trained = True
case _:
raise MemVidIndexError(f"Unsupported index type: {self.config.index_type}")
logger.info(f"Created {self.config.index_type} index with {self.config.metric} metric")
except Exception as e:
raise MemVidIndexError(f"Failed to create index: {str(e)}")
def add_texts(
self,
texts: List[str],
metadata: Optional[List[Dict[str, Any]]] = None,
):
"""Add texts and essential metadata to the index."""
try:
if self._index is None:
self.create_index()
# Convert texts to vectors using embeddings with progress bar
logger.info(f"Embedding {len(texts)} texts...")
with self._progress.progress(total=1, desc="Embedding texts") as pbar:
vectors = np.array(self.embeddings.embed_documents(texts), dtype='float32')
pbar.update(1)
# Use empty metadata if none provided
if metadata is None:
metadata = [{"text": text} for text in texts]
else:
# Ensure each metadata dict has the original text
for i, text in enumerate(texts):
if "text" not in metadata[i]:
metadata[i]["text"] = text
# Create a mapping of text to index for deduplication
text_to_idx = {m["text"]: i for i, m in enumerate(self._metadata)}
# Filter out duplicates and keep track of which texts to add
unique_indices = []
for i, text in self._progress.tqdm(enumerate(texts), desc="Deduplicating texts", total=len(texts)):
if text not in text_to_idx:
unique_indices.append(i)
text_to_idx[text] = len(self._metadata) + len(unique_indices) - 1
if not unique_indices:
logger.info("No new texts to add - all were duplicates")
return
# Filter vectors and metadata to only include unique texts
unique_vectors = vectors[unique_indices]
unique_metadata = [metadata[i] for i in unique_indices]
# Store only essential metadata in FAISS for efficiency
# Full metadata will be stored in video QR codes
essential_metadata = []
for i, meta in enumerate(unique_metadata):
# Create a stable hash of the metadata for integrity checking
metadata_str = str(sorted(meta.items()))
metadata_hash = hashlib.sha256(metadata_str.encode('utf-8')).hexdigest()
essential_meta = {
"text": meta.get("text", ""),
"id": len(self._metadata) + i, # Document ID for mapping
# Store only frequently accessed fields
"source": meta.get("source"),
"category": meta.get("category"),
# Store a hash of the full metadata for integrity checking
"metadata_hash": metadata_hash
}
essential_metadata.append(essential_meta)
# Check if we should convert to IVF index
if (
self.config.nlist > 0
and not isinstance(self._index, faiss.IndexIVFFlat)
and self._index.ntotal + len(unique_vectors) >= self._min_points
):
logger.info("Converting to IVF index...")
# Create IVF index
quantizer = self._index
metric = (
faiss.METRIC_INNER_PRODUCT if self.config.metric == "cosine"
else faiss.METRIC_L2
)
self._index = faiss.IndexIVFFlat(
quantizer,
self._dimension,
self.config.nlist,
metric
)
# Get all existing vectors with progress bar
batch_size = 1000
all_vectors = np.zeros(
(self._index.ntotal, self._dimension),
dtype='float32'
)
for i in self._progress.tqdm(
range(0, self._index.ntotal, batch_size),
desc="Reconstructing vectors"
):
end_idx = min(i + batch_size, self._index.ntotal)
self._index.reconstruct_n(i, end_idx - i, all_vectors[i:end_idx])
# Train the index
logger.info("Training IVF index...")
self._index.train(all_vectors)
self._is_trained = True
# Add back the vectors with progress bar
for i in self._progress.tqdm(
range(0, len(all_vectors), batch_size),
desc="Adding vectors to IVF index"
):
batch = all_vectors[i:i + batch_size]
self._index.add(batch)
logger.info(f"Converted to IVF index and trained with {self._index.ntotal} points")
# Check if IVF index needs training
if isinstance(self._index, faiss.IndexIVFFlat) and not self._is_trained:
# Train the index with these vectors
logger.info("Training IVF index...")
self._index.train(unique_vectors)
self._is_trained = True
# Normalize vectors for cosine similarity
if self.config.metric == "cosine":
faiss.normalize_L2(unique_vectors)
# Add vectors to index in batches with progress bar
batch_size = 1000
for i in self._progress.tqdm(range(0, len(unique_vectors), batch_size), desc="Adding vectors to index"):
batch = unique_vectors[i:i + batch_size]
self._index.add(batch)
self._metadata.extend(essential_metadata)
logger.info(f"Added {len(unique_vectors)} unique texts to index")
except Exception as e:
raise MemVidIndexError(f"Failed to add texts: {str(e)}")
def search_text(
self,
query_text: str,
k: int = 4,
) -> List[SearchResult]:
"""Search for similar texts using a text query."""
try:
if self._index is None:
raise MemVidIndexError("Index not initialized")
# Convert query text to vector
query_vector = np.array(self.embeddings.embed_query(query_text), dtype='float32').reshape(1, -1)
# Normalize query vector for cosine similarity
if self.config.metric == "cosine":
faiss.normalize_L2(query_vector)
# Search using the vector
distances, indices = self._index.search(query_vector, k)
            # Create SearchResult objects, skipping padding entries (FAISS
            # returns -1 indices when fewer than k results exist)
            results = []
            for idx, distance in zip(indices[0], distances[0]):
                if idx < 0:
                    continue
                metadata = self._metadata[idx]
                results.append(SearchResult.from_metadata(metadata, float(distance)))
            return results
except Exception as e:
raise MemVidIndexError(f"Failed to search text: {str(e)}")
def delete_by_ids(self, doc_ids: List[int]) -> bool:
"""Delete documents by their IDs and update index and mappings."""
try:
if self._index is None:
raise MemVidIndexError("Index not initialized")
if not doc_ids:
return False
# Validate IDs
max_id = len(self._metadata) - 1
invalid_ids = [doc_id for doc_id in doc_ids if doc_id < 0 or doc_id > max_id]
if invalid_ids:
raise MemVidIndexError(f"Invalid document IDs: {invalid_ids}")
            # Deduplicate and sort IDs in descending order so earlier
            # deletions don't shift the positions of later ones
            doc_ids = sorted(set(doc_ids), reverse=True)
            # Remove from metadata first
            for doc_id in doc_ids:
                del self._metadata[doc_id]
# Rebuild the index without the deleted vectors
self._rebuild_index_without_deleted(doc_ids)
logger.info(f"Deleted {len(doc_ids)} documents from index")
return True
except Exception as e:
raise MemVidIndexError(f"Failed to delete documents: {str(e)}")
def delete_by_texts(self, texts: List[str]) -> bool:
"""Delete documents by their text content."""
try:
if self._index is None:
raise MemVidIndexError("Index not initialized")
if not texts:
return False
# Find document IDs by text content
doc_ids = []
for text in texts:
for i, metadata in enumerate(self._metadata):
if metadata.get("text") == text:
doc_ids.append(i)
break # Only delete first occurrence of each text
if not doc_ids:
logger.info("No documents found with the specified texts")
return False
return self.delete_by_ids(doc_ids)
except Exception as e:
raise MemVidIndexError(f"Failed to delete documents by texts: {str(e)}")
def _rebuild_index_without_deleted(self, deleted_ids: List[int]):
"""Rebuild the index after deleting specified document IDs.
Args:
deleted_ids: List of document IDs that were deleted (in descending order)
Raises:
MemVidIndexError: If rebuilding fails
"""
try:
# Create a new index of the same type
self._index = None
self.create_index()
# If we have remaining documents, rebuild the index
if self._metadata:
# Get all remaining texts
remaining_texts = [metadata.get("text", "") for metadata in self._metadata]
# Convert texts to vectors
vectors = np.array(self.embeddings.embed_documents(remaining_texts), dtype='float32')
# Normalize vectors for cosine similarity
if self.config.metric == "cosine":
faiss.normalize_L2(vectors)
# Add vectors to the new index
self._index.add(vectors)
            # Reassign document IDs to be sequential starting from 0
            # (this invalidates any previously set frame mappings; callers
            # are expected to refresh them)
            for i, metadata in enumerate(self._metadata):
                metadata["id"] = i
logger.info(f"Rebuilt index with {len(self._metadata)} remaining documents")
except Exception as e:
raise MemVidIndexError(f"Failed to rebuild index: {str(e)}")
def get_all_documents(self) -> List[Dict[str, Any]]:
"""Get all documents in the index.
Returns:
List of all document metadata dictionaries
Raises:
MemVidIndexError: If retrieval fails
"""
try:
return self._metadata.copy()
except Exception as e:
raise MemVidIndexError(f"Failed to get all documents: {str(e)}")
def get_document_count(self) -> int:
"""Get the total number of documents in the index.
Returns:
Number of documents in the index
"""
        return len(self._metadata)
def set_frame_mapping(self, doc_id: int, frame_number: int):
"""Set the frame mapping for a document.
This method establishes the bidirectional mapping between document IDs and frame numbers,
which is essential for optimized deletion strategies.
Frame Index Mapping
        - Bidirectional Mapping: doc_id <-> frame_number for efficient lookups
- O(1) Lookup: Enables constant-time frame number retrieval
- Deletion Optimization: Allows precise frame-level deletion without full video rebuilds
- Consistency: Maintains synchronization between FAISS index and video frames
Performance Benefits
- Fast Deletion: O(k) time complexity where k = frames to delete
- Memory Efficient: Minimal memory overhead for mapping storage
- Scalable: Efficient for large document collections
- Reliable: Provides fallback mechanisms when mappings are corrupted
Use Cases
- Optimized Deletion: Enables selective frame removal from videos
- Frame Lookup: Fast retrieval of frame numbers for document IDs
- Document Lookup: Fast retrieval of document IDs for frame numbers
- Statistics: Provides mapping coverage statistics for monitoring
Args:
doc_id: Document ID
frame_number: Frame number in the video
Example:
# Set frame mapping for a document
index_manager.set_frame_mapping(123, 5)
# Retrieve frame number
frame_num = index_manager.get_frame_number(123) # Returns 5
# Retrieve document ID
doc_id = index_manager.get_document_id(5) # Returns 123
"""
self._frame_mapping[doc_id] = frame_number
self._reverse_frame_mapping[frame_number] = doc_id
def get_frame_number(self, doc_id: int) -> Optional[int]:
"""Get the frame number for a document.
Args:
doc_id: Document ID
Returns:
Frame number if found, None otherwise
"""
return self._frame_mapping.get(doc_id)
def get_document_id(self, frame_number: int) -> Optional[int]:
"""Get the document ID for a frame.
Args:
frame_number: Frame number
Returns:
Document ID if found, None otherwise
"""
return self._reverse_frame_mapping.get(frame_number)
def get_frames_to_delete(self, doc_ids: List[int]) -> List[int]:
"""Get frame numbers that need to be deleted for given document IDs.
This method is a key component of the optimized deletion strategy, enabling
precise frame-level deletion without full video rebuilds.
Optimization Strategy
- Frame Mapping Lookup: Uses O(1) lookup to find frame numbers for document IDs
- Safe Deletion Order: Returns frames in reverse order for safe deletion
- Efficient Processing: Processes multiple document IDs in a single operation
- Error Handling: Gracefully handles missing frame mappings
Performance Characteristics
- Lookup Time: O(k) where k = number of document IDs
- Memory Usage: Minimal temporary storage for frame numbers
- Scalability: Efficient for large-scale deletions
- Reliability: Handles missing mappings gracefully
Use Cases
- Video Frame Removal: Provides frame numbers for selective video editing
- Optimized Deletion: Enables efficient document removal without full rebuilds
- Batch Processing: Supports deletion of multiple documents at once
- Statistics: Provides data for deletion performance analysis
Args:
doc_ids: List of document IDs to delete
Returns:
List of frame numbers to delete, sorted in reverse order for safe deletion
Example:
# Get frames to delete for multiple documents
doc_ids = [0, 5, 10]
frames_to_delete = index_manager.get_frames_to_delete(doc_ids)
print(f"Frames to delete: {frames_to_delete}") # e.g., [10, 5, 0]
# Use frames for video editing
video_processor.remove_frames_from_video(video_path, frames_to_delete)
"""
frames = []
for doc_id in doc_ids:
frame_number = self._frame_mapping.get(doc_id)
if frame_number is not None:
frames.append(frame_number)
return sorted(frames, reverse=True) # Sort in descending order for safe deletion
def delete_frames_from_mapping(self, frame_numbers: List[int]):
"""Remove frame mappings for deleted frames.
Args:
frame_numbers: List of frame numbers that were deleted
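
        Example:
            An illustrative pairing with ``get_frames_to_delete``:

                frames = manager.get_frames_to_delete([0, 5, 10])
                manager.delete_frames_from_mapping(frames)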
"""
for frame_number in frame_numbers:
doc_id = self._reverse_frame_mapping.pop(frame_number, None)
if doc_id is not None:
self._frame_mapping.pop(doc_id, None)
def get_frame_mapping_stats(self) -> FrameMappingStats:
"""Get statistics about frame mappings for monitoring and optimization.
Returns:
FrameMappingStats: Statistics about frame mappings.
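
        Example:
            An illustrative check of mapping coverage:

                stats = manager.get_frame_mapping_stats()
                print(stats.mapping_coverage)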
"""
return FrameMappingStats(
total_documents=len(self._metadata),
mapped_documents=len(self._frame_mapping),
mapping_coverage=(len(self._frame_mapping) / len(self._metadata)) * 100 if self._metadata else 0,
mapping_efficiency={
"total_frames": len(self._reverse_frame_mapping),
"frame_range": {
"min": min(self._reverse_frame_mapping.keys()) if self._reverse_frame_mapping else None,
"max": max(self._reverse_frame_mapping.keys()) if self._reverse_frame_mapping else None
}
}
)
def save(self, path: Path = LANGCHAIN_MEMVID_DEFAULT_INDEX_DIR):
"""Save the index and metadata to disk.
Args:
path: Path to save the index, defaults to LANGCHAIN_MEMVID_DEFAULT_INDEX_DIR
Raises:
MemVidIndexError: If saving fails
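
        Note:
            Only the FAISS index and essential metadata are persisted; the
            in-memory frame mappings are not saved by this method.

        Example:
            An illustrative round trip (the path is hypothetical):

                manager.save(Path("/tmp/memvid_index"))
                manager.load(Path("/tmp/memvid_index"))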
"""
try:
if self._index is None:
raise MemVidIndexError("No index to save")
# Create directory if it doesn't exist
path.mkdir(parents=True, exist_ok=True)
# Save FAISS index
faiss.write_index(self._index, str(path / "index.faiss"))
# Save metadata
with open(path / "metadata.json", "wb") as f:
f.write(orjson.dumps(self._metadata, option=orjson.OPT_NON_STR_KEYS))
logger.info(f"Saved index to {path}")
except Exception as e:
raise MemVidIndexError(f"Failed to save index: {str(e)}")
def load(self, path: Path = LANGCHAIN_MEMVID_DEFAULT_INDEX_DIR):
"""Load the index and metadata from disk.
Args:
path: Path to load the index from, defaults to LANGCHAIN_MEMVID_DEFAULT_INDEX_DIR
Raises:
MemVidIndexError: If loading fails
"""
try:
            if not path.exists() or not path.is_dir():
                raise FileNotFoundError(f"Path {path} does not exist or is not a directory")
index_file = path / "index.faiss"
metadata_file = path / "metadata.json"
if not index_file.exists():
raise FileNotFoundError(f"Index file not found at {index_file}")
if not metadata_file.exists():
raise FileNotFoundError(f"Metadata file not found at {metadata_file}")
            # Load FAISS index
            self._index = faiss.read_index(str(index_file))
            self._dimension = self._index.d
            # A loaded index is already trained (if applicable)
            self._is_trained = self._index.is_trained
            # Load metadata
            with open(metadata_file, "rb") as f:
                self._metadata = orjson.loads(f.read())
logger.info(f"Loaded index from {path}")
except Exception as e:
raise MemVidIndexError(f"Failed to load index: {str(e)}")