# Source code for afew.filters.BaseFilter

#  -*- coding: utf-8 -*-
# SPDX-License-Identifier: ISC
# Copyright (c) Justus Winter <>

from __future__ import print_function, absolute_import, unicode_literals

import collections
import logging

import notmuch

class Filter(object):
    '''
    Base class for tagging filters.

    ``run`` selects messages and passes each one to ``handle_message``.
    ``handle_message`` enqueues tag changes in memory. ``commit`` writes
    the enqueued changes to the database and then clears them.
    Subclasses may override ``handle_message`` or the class attributes
    below to customise behaviour.
    '''

    # Description string; not read anywhere in this class — presumably
    # consumed by callers (verify).
    message = 'No message specified for filter'
    # Default tag actions ('+name' adds, '-name' removes). __init__ uses
    # this when no ``tags`` keyword argument is given.
    tags = []
    # handle_message leaves alone any message that carries one of these tags.
    tags_blacklist = []

    def __init__(self, database, **kwargs):
        '''
        :param database: database wrapper. ``run`` calls its
            ``get_messages(query)`` and ``commit`` calls its
            ``open(rw=True)``.
        :param kwargs: each keyword argument is stored as an instance
            attribute (e.g. ``tags``, ``tags_blacklist``, ``query``),
            overriding the class-level defaults.
        :raises ValueError: if an entry of ``tags`` is empty or does not
            start with ``+`` or ``-``.
        '''
        super(Filter, self).__init__()

        self.log = logging.getLogger('{}.{}'.format(
            self.__module__, self.__class__.__name__))

        self.database = database
        if 'tags' not in kwargs:
            kwargs['tags'] = self.tags
        for key, value in kwargs.items():
            setattr(self, key, value)

        self.flush_changes()

        # Split '+name' / '-name' actions into separate add/remove lists.
        self._tags_to_add = []
        self._tags_to_remove = []
        for tag_action in self.tags:
            # Test for the empty string first: indexing it would raise
            # IndexError rather than the documented ValueError.
            if not tag_action or tag_action[0] not in '+-':
                raise ValueError('Each tag must be preceded by either + or -')
            if tag_action[0] == '+':
                self._tags_to_add.append(tag_action[1:])
            else:
                self._tags_to_remove.append(tag_action[1:])

        self._tag_blacklist = set(self.tags_blacklist)

    def flush_changes(self):
        '''
        (Re)Initializes the data structures that hold the enqueued
        changes to the notmuch database.
        '''
        # message id -> set of tags to add
        self._add_tags = collections.defaultdict(set)
        # message id -> set of tags to remove
        self._remove_tags = collections.defaultdict(set)
        # message ids that should have all tags removed before adding
        self._flush_tags = []

    def run(self, query):
        '''
        Call ``handle_message`` on every message matching *query*.

        If the filter has a truthy ``query`` attribute, it is combined with
        *query* using AND. When *query* is empty, the filter's own query is
        used alone.
        '''
        own_query = getattr(self, 'query', None)
        if own_query:
            if query:
                query = '(%s) AND (%s)' % (query, own_query)
            else:
                query = own_query

        for message in self.database.get_messages(query):
            self.handle_message(message)

    def handle_message(self, message):
        '''
        Enqueue this filter's configured tag actions for *message*.

        Does nothing if the message carries any blacklisted tag.
        '''
        if not self._tag_blacklist.intersection(message.get_tags()):
            self.remove_tags(message, *self._tags_to_remove)
            self.add_tags(message, *self._tags_to_add)

    def add_tags(self, message, *tags):
        '''Enqueue adding *tags* to *message*; no-op if *tags* is empty.'''
        if tags:
            message_id = message.get_message_id()
            self.log.debug('Adding tags %s to id:%s',
                           ', '.join(tags), message_id)
            self._add_tags[message_id].update(tags)

    def remove_tags(self, message, *tags):
        '''
        Enqueue removing *tags* from *message*; no-op if *tags* is empty.

        The 'unread' and 'read' tags are never enqueued for removal.
        '''
        if tags:
            # Drop each protected tag independently, so that 'read' is
            # filtered out even when 'unread' is not among *tags*.
            filtered_tags = [tag for tag in tags
                             if tag not in ('unread', 'read')]
            message_id = message.get_message_id()
            self.log.debug('Removing tags %s from id:%s',
                           ', '.join(filtered_tags), message_id)
            self._remove_tags[message_id].update(filtered_tags)

    def flush_tags(self, message):
        '''Enqueue removing every tag from *message* at commit time.'''
        message_id = message.get_message_id()
        self.log.debug('Removing all tags from id:%s', message_id)
        self._flush_tags.append(message_id)

    def commit(self, dry_run=True):
        '''
        Write all enqueued tag changes to the database, then clear them.

        For each affected message, the steps run in this order:
        1. remove all tags (if ``flush_tags`` was called for it);
        2. add the enqueued tags;
        3. remove the enqueued tags.

        :param dry_run: if true, only log how many messages would change;
            the enqueued changes are still cleared afterwards.
        '''
        dirty_messages = set(self._flush_tags)
        dirty_messages.update(self._add_tags)
        dirty_messages.update(self._remove_tags)

        if not dirty_messages:
            return

        if dry_run:
            self.log.info('I would commit changes to %i messages',
                          len(dirty_messages))
        else:
            self.log.info('Committing changes to %i messages',
                          len(dirty_messages))
            # NOTE(review): the right-hand side of this assignment was lost
            # in the copy this file was recovered from; reconstructed as a
            # read-write open of the database wrapper — confirm.
            db = self.database.open(rw=True)
            # Look up ids in a set instead of scanning the list each time.
            flush_ids = set(self._flush_tags)

            for message_id in dirty_messages:
                message = db.find_message(message_id)
                if message is None:
                    # The lookup can come back empty, presumably when the
                    # message was removed after being queued. Skip it
                    # rather than abort the whole commit with AttributeError.
                    self.log.warning('Message id:%s not found, skipping',
                                     message_id)
                    continue

                if message_id in flush_ids:
                    message.remove_all_tags()
                for tag in self._add_tags.get(message_id, ()):
                    message.add_tag(tag)
                for tag in self._remove_tags.get(message_id, ()):
                    message.remove_tag(tag)

        self.flush_changes()