from numbers import Integral as Integral, Real as Real
from typing import Any, ClassVar, Literal
from typing_extensions import Self

import numpy as np
import scipy.sparse as sp
from joblib import effective_n_jobs as effective_n_jobs
from numpy import ndarray
from numpy.random import RandomState
from scipy.special import gammaln as gammaln, logsumexp as logsumexp

from .._typing import ArrayLike, Float, Int, MatrixLike
from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
from ..utils import check_random_state as check_random_state, gen_batches as gen_batches, gen_even_slices as gen_even_slices
from ..utils._param_validation import Interval as Interval, StrOptions as StrOptions
from ..utils.parallel import Parallel as Parallel, delayed as delayed
from ..utils.validation import check_is_fitted as check_is_fitted, check_non_negative as check_non_negative
from ._online_lda_fast import mean_change as cy_mean_change

# Module-level constant; its value is elided in this stub.
# NOTE(review): presumably a small floating-point epsilon — confirm against
# the implementation module before annotating a concrete type here.
EPS = ...

class LatentDirichletAllocation(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
    """Typed interface of the ``LatentDirichletAllocation`` estimator.

    This is a declaration-only stub: attribute values and method bodies are
    elided with ``...``. It records the constructor parameters with their
    types and defaults, the typed attributes exposed on the class, and the
    signatures of the public methods.
    """

    # Class-level (not per-instance) mapping, as marked by ``ClassVar``.
    # NOTE(review): presumably consumed by the parameter-validation helpers
    # imported at the top of the file — confirm against the implementation.
    _parameter_constraints: ClassVar[dict] = ...

    # Typed attributes, grouped by declared type.
    # NOTE(review): the trailing underscore suggests these are populated by
    # fitting — confirm against the implementation.

    # Array-valued attributes.
    components_: ndarray = ...
    exp_dirichlet_component_: ndarray = ...
    feature_names_in_: ndarray = ...

    # Integer-valued attributes.
    n_batch_iter_: int = ...
    n_features_in_: int = ...
    n_iter_: int = ...

    # Float-valued attributes.
    bound_: float = ...
    doc_topic_prior_: float = ...
    topic_word_prior_: float = ...

    # Random number generator attribute.
    random_state_: RandomState = ...

    def __init__(
        self,
        n_components: Int = 10,
        *,
        doc_topic_prior: Float | None = None,
        topic_word_prior: Float | None = None,
        learning_method: Literal["batch", "online"] = "batch",
        learning_decay: Float = 0.7,
        learning_offset: Float = 10.0,
        max_iter: Int = 10,
        batch_size: Int = 128,
        evaluate_every: Int = ...,
        total_samples: float | Int = 1e6,
        perp_tol: Float = 1e-1,
        mean_change_tol: Float = 1e-3,
        max_doc_update_iter: Int = 100,
        n_jobs: Int | None = None,
        verbose: Int = 0,
        random_state: Int | RandomState | None = None,
    ) -> None:
        """Declare the constructor signature.

        ``n_components`` is the only parameter that may be passed
        positionally; everything after the bare ``*`` is keyword-only.
        ``learning_method`` is restricted to ``"batch"`` or ``"online"``.
        The default of ``evaluate_every`` is elided in this stub.
        """
        ...

    def partial_fit(
        self,
        X: MatrixLike | ArrayLike,
        y: Any = None,
    ) -> Self:
        """Take matrix-like or array-like ``X`` and return the estimator.

        ``y`` is optional and untyped (``Any``).
        NOTE(review): presumably present only for API compatibility — confirm.
        """
        ...

    def fit(
        self,
        X: MatrixLike | ArrayLike,
        y: Any = None,
    ) -> Self:
        """Take matrix-like or array-like ``X`` and return the estimator.

        ``y`` is optional and untyped (``Any``).
        NOTE(review): presumably present only for API compatibility — confirm.
        """
        ...

    def transform(
        self,
        X: MatrixLike | ArrayLike,
    ) -> ndarray:
        """Take matrix-like or array-like ``X`` and return an ``ndarray``."""
        ...

    def score(
        self,
        X: MatrixLike | ArrayLike,
        y: Any = None,
    ) -> float:
        """Take matrix-like or array-like ``X`` and return a ``float``.

        ``y`` is optional and untyped (``Any``).
        """
        ...

    def perplexity(
        self,
        X: MatrixLike | ArrayLike,
        sub_sampling: bool = False,
    ) -> float:
        """Take matrix-like or array-like ``X`` and return a ``float``.

        ``sub_sampling`` is a boolean flag that defaults to ``False``.
        """
        ...
