from django import forms
from haystack.forms import SearchForm
from iconolab.models import Item, Annotation, Collection




class IconolabSearchForm(SearchForm):
    """Haystack search form restricted by model type, collection and tags.

    Extra constructor keyword arguments (both optional):

    * ``collection_name`` -- when given, results are filtered on the indexed
      ``collection`` field, and the matching :class:`Collection` (or ``None``
      if there is none) is stored on ``self.collection``.
    * ``model_type`` -- forces the ``model_type`` field to this value in the
      submitted data (e.g. when the type comes from the URL rather than from
      the query string).
    """

    model_type = forms.ChoiceField(
        required=False,
        choices=(("images", "Images"), ("annotations", "Annotations")),
    )
    # When checked, the query string must also match the indexed ``tags`` field.
    tags = forms.BooleanField(required=False, initial=False)

    def __init__(self, *args, **kwargs):
        # Default to None so the form can be built without a collection scope
        # (a missing kwarg used to raise KeyError).
        self.collection_name = kwargs.pop("collection_name", None)
        # Always define ``self.collection``; one query replaces the former
        # exists() + get() pair.
        self.collection = None
        if self.collection_name:
            self.collection = Collection.objects.filter(
                name=self.collection_name
            ).first()

        selected_model_type = kwargs.pop("model_type", None)
        if selected_model_type is not None:
            # Django forms accept ``data`` either positionally or by keyword;
            # apply the override in both cases.
            if args:
                args = (self._with_model_type(args[0], selected_model_type),) + args[1:]
            elif "data" in kwargs:
                kwargs["data"] = self._with_model_type(
                    kwargs["data"], selected_model_type
                )

        super(IconolabSearchForm, self).__init__(*args, **kwargs)

    @staticmethod
    def _with_model_type(data, model_type):
        """Return a copy of ``data`` with ``model_type`` forced, or ``None``.

        The copy keeps the caller's mapping (e.g. a request ``QueryDict``)
        untouched. ``None`` is passed through so an unbound form stays unbound;
        an empty mapping still receives the override.
        """
        if data is None:
            return None
        data = data.copy()
        data["model_type"] = model_type
        return data

    def no_query_found(self):
        """Return all indexed objects, narrowed by model type and collection.

        Haystack's ``SearchForm.search`` uses this hook when there is no usable
        query. The tag filter is not applied because there is no query string
        to match tags against.
        """
        # ``cleaned_data`` is never created on an unbound form.
        cleaned_data = getattr(self, "cleaned_data", {})
        selected_type = cleaned_data.get("model_type")
        return self.get_model_type_queryset(self.searchqueryset, selected_type).load_all()

    def get_model_type_queryset(self, qs, model_type, tags_only=False):
        """Narrow a search queryset by model type, collection and tags.

        :param qs: the search queryset to narrow.
        :param model_type: ``"images"`` restricts results to :class:`Item`,
            ``"annotations"`` to :class:`Annotation`. Any other value (including
            ``None`` or ``""``) applies no model restriction. For the selected
            model, the Django queryset used when loading objects is replaced by
            one with ``select_related`` joins.
        :param tags_only: when true, also require the ``tags`` field to equal
            the submitted ``q`` value. Defaults to ``False`` so callers without
            a query (``no_query_found``) can omit it.
        :returns: the narrowed queryset.
        """
        if model_type == "images":
            qs = qs.models(Item).load_all_queryset(
                Item,
                Item.objects.select_related("collection", "metadatas"),
            )
        elif model_type == "annotations":
            qs = qs.models(Annotation).load_all_queryset(
                Annotation,
                Annotation.objects.select_related(
                    "image",
                    "image__item",
                    "image__item__collection",
                    "stats",
                    "current_revision",
                    "author",
                ),
            )

        if self.collection_name is not None:
            qs = qs.filter(collection=self.collection_name)

        if tags_only:
            qs = qs.filter(tags=self.cleaned_data.get("q"))

        return qs

    def search(self):
        """Run the search and restrict results to the selected model type,
        collection and (optionally) tags.

        :returns: a search queryset with ``load_all()`` applied.
        """
        # The parent validates the form; ``cleaned_data`` must not be read
        # before this call, because it does not exist until validation runs.
        qs = super(IconolabSearchForm, self).search()

        # Invalid form or empty ``q``: the parent returned ``no_query_found()``,
        # which has already applied the model/collection filters.
        if not self.is_valid() or not self.cleaned_data.get("q"):
            return qs

        selected_type = self.cleaned_data.get("model_type")
        tags_only = self.cleaned_data.get("tags")
        return self.get_model_type_queryset(qs, selected_type, tags_only).load_all()
