"""
Programowanie obiektowe w Pythonie 3; Studium przypadku

Rozdział 10., Wzorzec Iterator
"""
from __future__ import annotations
import bisect
import heapq
import collections
from typing import cast, NamedTuple, Callable, Iterable, List, Union, Counter


class Sample(NamedTuple):
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float


class KnownSample(NamedTuple):
    sample: Sample
    species: str


class TestingKnownSample(NamedTuple):
    sample: KnownSample


class TrainingKnownSample(NamedTuple):
    sample: KnownSample


TestingList = List[TestingKnownSample]
TrainingList = List[TrainingKnownSample]


class UnknownSample(NamedTuple):
    sample: Sample


class ClassifiedKnownSample(NamedTuple):
    sample: KnownSample
    classification: str


class ClassifiedUnknownSample(NamedTuple):
    sample: UnknownSample
    classification: str


AnySample = Union[KnownSample, UnknownSample]
DistanceFunc = Callable[[TrainingKnownSample, AnySample], float]


class Measured(NamedTuple):
    """Zmierzona odległość jest zapisywana jako pierwsza, by uprościć sortowanie"""

    distance: float
    sample: TrainingKnownSample


import itertools
from typing import DefaultDict, Tuple, Iterator

ModuloDict = DefaultDict[int, List[KnownSample]]


def partition_2(
    samples: Iterable[KnownSample], training_rule: Callable[[int], bool]
) -> tuple[TrainingList, TestingList]:
    """Rozdzielamy na różne kubełki.
    Kubełki są łączone w 2 zbiory: testowy i uczący.
    Zastosowano listy.

    Wyniki są w dużym stopniu uzależnione od losowości ziarna; można skorzystać
    ze zmiennej środowiskowej PYTHONHASHSEED=0
    """
    rule_multiple = 60
    partitions: ModuloDict = collections.defaultdict(list)
    for s in samples:
        partitions[hash(s) % rule_multiple].append(s)

    training_partitions: list[Iterator[TrainingKnownSample]] = []
    testing_partitions: list[Iterator[TestingKnownSample]] = []
    for i, p in enumerate(partitions.values()):
        if training_rule(i):
            training_partitions.append(TrainingKnownSample(s) for s in p)
        else:
            testing_partitions.append(TestingKnownSample(s) for s in p)

    training = list(itertools.chain(*training_partitions))
    testing = list(itertools.chain(*testing_partitions))
    return training, testing


test_partition_2 = """
Dla niektórych wartości ziarna oraz bardzo mały pul danych,
*mogą* wystąpić kolizje.
Aby móc korzystać z małych zbiorów danych i zminimalizować ryzyko
wystąpienia kolizji zastosowaliśmy liczby pierwsze mniejsze od 250

>>> data = [
...     KnownSample(sample=Sample(2, 3, 5, 7), species="C"),
...     KnownSample(sample=Sample(11, 13, 17, 19), species="G"),
...     KnownSample(sample=Sample(23, 29, 31, 37), species="I"),
...     KnownSample(sample=Sample(41, 43, 47, 53), species="O"),
... ]

Reguła 25%/75%
>>> train, test = partition_2(data, lambda i: i % 4 != 0)
>>> len(train)
3
>>> len(test)
1

"""


def k_nn_1(
    k: int, dist: DistanceFunc, training_data: TrainingList, unknown: AnySample
) -> str:
    """
    1.	Utworzenie listy wszystkich par (odległość, próbka ucząca).
    
    2.	Posortowanie tych par w kolejności malejącej.
    
    3.	Wybranie pierwszych k par, czyli k najbliższych sąsiadów.
    
    4.	Wybranie etykiety dominanty (najwyższej częstotliwości) dla k najbliższych sąsiadów.
    

    >>> data = [
    ...     TrainingKnownSample(KnownSample(sample=Sample(1, 2, 3, 4), species="a")),
    ...     TrainingKnownSample(KnownSample(sample=Sample(2, 3, 4, 5), species="b")),
    ...     TrainingKnownSample(KnownSample(sample=Sample(3, 4, 5, 6), species="c")),
    ...     TrainingKnownSample(KnownSample(sample=Sample(4, 5, 6, 7), species="d")),
    ... ]
    >>> dist = lambda ts1, u2: max(abs(ts1.sample.sample[i] - u2.sample[i]) for i in range(len(ts1)))
    >>> k_nn_1(1, dist, data, UnknownSample(Sample(1.1, 2.1, 3.1, 4.1)))
    'a'
    """
    distances = sorted(map(lambda t: Measured(dist(t, unknown), t), training_data))
    k_nearest = distances[:k]
    k_frequencies: Counter[str] = collections.Counter(
        s.sample.sample.species for s in k_nearest
    )
    mode, fq = k_frequencies.most_common(1)[0]
    return mode


def k_nn_b(
    k: int, dist: DistanceFunc, training_data: TrainingList, unknown: AnySample
) -> str:
    """
    Używamy funkcji bisect.bisect_left() by zachować krótą listę _k_ najbliższych sąsiadów, posortowaną.
    
    1.	Dla każdej próbki uczącej:
    
        1)	Obliczamy odległość pomiędzy tą próbką uczącą oraz nieznaną próbką.
        
        2)	Jeśli jest ona większa od k najbliższych sąsiadów znalezionych do tej pory, to odrzucamy odległość.
        
        3)	W przeciwnym razie znajdujemy pomiędzy k wartościami miejsce w którym należy umieścić nową wartość, wstawiamy ją w to miejsce, po czym skracamy listę tak by ponownie zawierała k elementów.
        
    2.	Znajdujemy częstotliwości wyników pomiędzy k najbliższymi sąsiadami.
    
    3.	Wybieramy dominantę (najwyższą częstotliwość) pomiędzy k najbliższymi sąsiadami.


    >>> data = [
    ...     TrainingKnownSample(KnownSample(sample=Sample(1, 2, 3, 4), species="a")),
    ...     TrainingKnownSample(KnownSample(sample=Sample(2, 3, 4, 5), species="b")),
    ...     TrainingKnownSample(KnownSample(sample=Sample(3, 4, 5, 6), species="c")),
    ...     TrainingKnownSample(KnownSample(sample=Sample(4, 5, 6, 7), species="d")),
    ... ]
    >>> dist = lambda ts1, u2: max(abs(ts1.sample.sample[i] - u2.sample[i]) for i in range(len(ts1)))
    >>> k_nn_b(1, dist, data, UnknownSample(Sample(1.1, 2.1, 3.1, 4.1)))
    'a'
    """
    k_nearest = [
        Measured(float("inf"), cast(TrainingKnownSample, None)) for _ in range(k)
    ]
    for t in training_data:
        t_dist = dist(t, unknown)
        if t_dist > k_nearest[-1].distance:
            continue
        new = Measured(t_dist, t)
        k_nearest.insert(bisect.bisect_left(k_nearest, new), new)
        k_nearest.pop(-1)
    k_frequencies: Counter[str] = collections.Counter(
        s.sample.sample.species for s in k_nearest
    )
    mode, fq = k_frequencies.most_common(1)[0]
    return mode


def k_nn_q(
    k: int, dist: DistanceFunc, training_data: TrainingList, unknown: AnySample
) -> str:
    """
    Używamy modułu heapq do utrzymania listy sąsiadów w kolejności posortowanej,
    unikając przy tym sortowania po wykonaniu wszystkich obliczeń.

    1. Dla każdego elementu:

        1. Obliczamy odległość.

        2. Wstawiamy wyniki do kopca, zachowując kolejność odległości.

    2.  Znajdujemy częstotliwości wartości wynikowych dla _k_ najbliższych sąsiadów.

    3.  Wybieramy dominantę (najwyższą częstotliwość) pomiędzy k najbliższymi sąsiadami.

    >>> data = [
    ...     TrainingKnownSample(KnownSample(sample=Sample(1, 2, 3, 4), species="a")),
    ...     TrainingKnownSample(KnownSample(sample=Sample(2, 3, 4, 5), species="b")),
    ...     TrainingKnownSample(KnownSample(sample=Sample(3, 4, 5, 6), species="c")),
    ...     TrainingKnownSample(KnownSample(sample=Sample(4, 5, 6, 7), species="d")),
    ... ]
    >>> dist = lambda ts1, u2: max(abs(ts1.sample.sample[i] - u2.sample[i]) for i in range(len(ts1)))
    >>> k_nn_q(1, dist, data, UnknownSample(Sample(1.1, 2.1, 3.1, 4.1)))
    'a'
    """
    measured_iter = (Measured(dist(t, unknown), t) for t in training_data)
    k_nearest = heapq.nsmallest(k, measured_iter)
    k_frequencies: Counter[str] = collections.Counter(
        s.sample.sample.species for s in k_nearest
    )
    mode, fq = k_frequencies.most_common(1)[0]
    return mode


Classifier = Callable[[int, DistanceFunc, TrainingList, AnySample], str]


class Hyperparameter(NamedTuple):
    k: int
    distance_function: DistanceFunc
    training_data: TrainingList
    classifier: Classifier

    def classify(self, unknown: AnySample) -> str:
        classifier = self.classifier
        return classifier(self.k, self.distance_function, self.training_data, unknown)

    def test(self, testing: TestingList) -> int:
        classifier = self.classifier
        test_results = (
            ClassifiedKnownSample(
                t.sample,
                classifier(
                    self.k, self.distance_function, self.training_data, t.sample
                ),
            )
            for t in testing
        )
        pass_fail = map(
            lambda t: (1 if t.sample.species == t.classification else 0), test_results
        )
        return sum(pass_fail)


def minkowski(
    s1: TrainingKnownSample,
    s2: AnySample,
    m: int,
    summarize: Callable[[Iterable[float]], float] = sum,
) -> float:
    return (
        summarize(
            [
                abs(s1.sample.sample.sepal_length - s2.sample.sepal_length) ** m,
                abs(s1.sample.sample.sepal_width - s2.sample.sepal_width) ** m,
                abs(s1.sample.sample.petal_length - s2.sample.petal_length) ** m,
                abs(s1.sample.sample.petal_width - s2.sample.petal_width) ** m,
            ]
        )
        ** (1 / m)
    )


def manhattan(s1: TrainingKnownSample, s2: AnySample) -> float:
    return minkowski(s1, s2, m=1)


def euclidean(s1: TrainingKnownSample, s2: AnySample) -> float:
    return minkowski(s1, s2, m=2)


def chebyshev(s1: TrainingKnownSample, s2: AnySample) -> float:
    return minkowski(s1, s2, m=1, summarize=max)


test_hyperparameter = """
>>> data = [
...     KnownSample(sample=Sample(1, 2, 3, 4), species="a"),
...     KnownSample(sample=Sample(2, 3, 4, 5), species="b"),
...     KnownSample(sample=Sample(3, 4, 5, 6), species="c"),
...     KnownSample(sample=Sample(4, 5, 6, 7), species="d"),
... ]
>>> training_data = [TrainingKnownSample(s) for s in data]
>>> h = Hyperparameter(1, manhattan, training_data, k_nn_1)
>>> h.classify(UnknownSample(Sample(2, 3, 4, 5)))
'b'
>>> testing_data = [
...     TestingKnownSample(KnownSample(sample=Sample(1.1, 2.1, 3.1, 4.1), species="a")),
...     TestingKnownSample(KnownSample(sample=Sample(1.2, 2.2, 3.2, 4.3), species="b")),
... ]
>>> h.test(testing_data)/len(testing_data)
0.5
"""

__test__ = {name: case for name, case in globals().items() if name.startswith("test_")}
