# -*- coding: utf-8 -*-
'''
Created on Feb 14, 2013

@author: ymh
'''

from haystack.query import RelatedSearchQuerySet
from haystack import connections
from haystack.exceptions import NotHandled
from haystack.constants import ITERATOR_LOAD_PER_QUERY

class ModelRelatedSearchQuerySet(RelatedSearchQuerySet):
    """``RelatedSearchQuerySet`` that can force all results onto one model.

    When ``model`` is given, ``load_all`` object resolution ignores each
    result's own ``result.model``. It groups every primary key under that
    single model and looks the objects up there. With ``model=None`` each
    result's own ``result.model`` is used, as in the parent class.
    """

    # Model class used in place of ``result.model`` when loading objects;
    # ``None`` means "use the model reported by each search result".
    _model = None

    def __init__(self, using=None, query=None, model=None):
        """Create the queryset.

        :param using: connection alias, forwarded to the parent constructor.
        :param query: existing query object, forwarded to the parent
            constructor.
        :param model: optional model class that overrides ``result.model``
            when objects are loaded (see the class docstring).
        """
        super(ModelRelatedSearchQuerySet, self).__init__(using=using, query=query)
        self._model = model

    def _result_model(self, result):
        """Return the model under which ``result``'s object is loaded."""
        return result.model if self._model is None else self._model

    def _load_objects(self, results):
        """Bulk-load the database objects backing ``results``.

        Primary keys are grouped per model (see ``_result_model``). Each group
        is fetched with ``in_bulk``, either through a queryset registered in
        ``self._load_all_querysets`` or through the search index's
        ``load_all_queryset()``.

        :returns: dict mapping model -> the ``in_bulk`` result for that model.
            Models the unified index does not handle map to an empty dict.
        """
        models_pks = {}
        for result in results:
            models_pks.setdefault(self._result_model(result), []).append(result.pk)

        loaded_objects = {}
        for model, pks in models_pks.items():
            if model in self._load_all_querysets:
                # Use the overriding queryset.
                loaded_objects[model] = self._load_all_querysets[model].in_bulk(pks)
                continue

            # Check the SearchIndex for the model for an override.
            try:
                index = connections[self.query._using].get_unified_index().get_index(model)
                qs = index.load_all_queryset()
                loaded_objects[model] = qs.in_bulk(pks)
            except NotHandled:
                # The model doesn't seem to be handled by the routers: load
                # nothing for it. An empty *dict* (not a list) makes the
                # per-result lookup raise KeyError for any pk type, including
                # non-integer pks, so those results are skipped silently.
                loaded_objects[model] = {}
        return loaded_objects

    def _fill_cache(self, start, end):
        """Fetch one slice of search results and append it to the cache.

        :param start: start offset handed to ``query.set_limits``.
        :param end: end offset handed to ``query.set_limits``.
        :returns: ``False`` if the backend returned no results for the slice,
            ``True`` otherwise.

        With ``load_all`` enabled, each result gets its database object in
        ``result._object``. Results whose object cannot be found are dropped
        and counted in ``self._ignored_result_count``.
        """
        # Tell the query where to start from and how many we'd like.
        self.query._reset()
        self.query.set_limits(start, end)
        results = self.query.get_results()

        if not results:
            return False

        loaded_objects = None
        if self._load_all:
            loaded_objects = self._load_objects(results)

        # NOTE(review): a page shorter than ITERATOR_LOAD_PER_QUERY that does
        # not reach the total hit count has its shortfall counted as ignored.
        # Presumably this keeps len(self) consistent with what iteration can
        # yield; confirm against haystack's SearchQuerySet.
        if len(results) + len(self._result_cache) < len(self) and len(results) < ITERATOR_LOAD_PER_QUERY:
            self._ignored_result_count += ITERATOR_LOAD_PER_QUERY - len(results)

        for result in results:
            if self._load_all:
                # We have to deal with integer keys being cast from strings; if
                # this fails we've got a character pk and keep it unchanged.
                try:
                    result.pk = int(result.pk)
                except ValueError:
                    pass
                try:
                    result._object = loaded_objects[self._result_model(result)][result.pk]
                except (KeyError, IndexError):
                    # The object was either deleted since we indexed or should
                    # be ignored; fail silently.
                    self._ignored_result_count += 1
                    continue

            self._result_cache.append(result)

        return True

    def _clone(self, klass=None):
        """Return a copy of this queryset built around a cloned query.

        :param klass: class to instantiate; defaults to this instance's class.
        :returns: new queryset carrying over the ``load_all`` flag, the
            per-model override querysets and the forced ``_model``.

        The override mapping is copied (shallowly), so later mutations of the
        clone's mapping do not leak back into this queryset.
        """
        if klass is None:
            klass = self.__class__

        query = self.query._clone()
        clone = klass(query=query)
        clone._load_all = self._load_all
        clone._load_all_querysets = dict(self._load_all_querysets)
        clone._model = self._model
        return clone

