# =======================================================================
# KAIRA v1.2 MCTS Motoru (D18 "BÜYÜK BEYİN" UYUMLU)
# =======================================================================
#
# GÜNCELLEME v1.2 (D18):
# 1. "_board_to_tensor" fonksiyonu, KAIRA v0.8 (Büyük Beyin)
#    eğitim koduna UYGUN şekilde (8, 8, 18) tensör üretecek
#    şekilde DÜZELTİLDİ.
# 2. 6 kural katmanı (Sıra, Rok, En Passant) eklendi.
# 3. v1.1'deki "tertemiz" WDL (Değer) düzeltmesi korundu.
#
# =======================================================================

import numpy as np
import tensorflow as tf
import chess
import time
import math
import sys

# --- (Softmax ve Harita Yükleme v1.1 ile AYNI) ---
def _softmax(x: np.ndarray) -> np.ndarray:
    a = np.asarray(x, dtype=np.float64)
    if a.size == 0: return a
    a = a - np.max(a)
    ex = np.exp(a)
    s = np.sum(ex)
    if s == 0: return ex
    return (ex / s).astype(np.float32)

try:
    # Bu dosyanın, motorun çalıştığı yerde olması gerekir
    from kaira_harita_eski import policy_index
    print("✅ Harita ('kaira_harita_eski.py') başarıyla bulundu ve yüklendi.")
    INDEX_TO_MOVE_UCI = policy_index
    MOVE_UCI_TO_INDEX = {move_uci: index for index, move_uci in enumerate(policy_index)}
except ImportError:
    print("❌ HATA: 'kaira_harita.py' dosyası bu klasörde bulunamadı!"); sys.exit()

FLIPPED_INDEX_MAP = {}
try:
    for idx, uci in enumerate(INDEX_TO_MOVE_UCI):
        try: mv = chess.Move.from_uci(uci)
        except Exception: continue
        try:
            f_from = chess.square_mirror(mv.from_square); f_to = chess.square_mirror(mv.to_square); prom = ''
            if mv.promotion is not None: prom = chess.piece_symbol(mv.promotion).lower()
            f_uci = chess.square_name(f_from) + chess.square_name(f_to) + prom
            flipped_idx = MOVE_UCI_TO_INDEX.get(f_uci)
            if flipped_idx is not None: FLIPPED_INDEX_MAP[idx] = flipped_idx
        except Exception: continue
except Exception: FLIPPED_INDEX_MAP = {}
# --- (Harita Yükleme Bitti) ---


# =======================================================================
# --- 2. KAIRA'NIN BEYNİ (AĞ KATMANI) (D18 GÜNCELLEMELİ) ---
# =======================================================================
class KAIRA_Network:
    """
    Bu sınıf, dondurulmuş .pb modelini yükler ve bir TensorFlow
    oturumunda (session) çalıştırır. (v1.2 D18 UYUMLU)
    """
    def __init__(self, model_path):
        print(f"KAIRA Beyni (.pb Modeli) yükleniyor: {model_path}")
        try:
            devices = tf.config.list_physical_devices()
            print(f"TensorFlow cihazları: {devices}")
            self.graph = tf.Graph()
            config = tf.compat.v1.ConfigProto()
            config.gpu_options.allow_growth = True
            self.session = tf.compat.v1.Session(graph=self.graph, config=config)
        except Exception as e: print(f"❌ HATA: TensorFlow oturumu başlatılamadı: {e}"); sys.exit()
        try:
            with self.graph.as_default():
                graph_def = tf.compat.v1.GraphDef()
                with tf.compat.v1.gfile.GFile(model_path, 'rb') as f:
                    graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name="")
                print("✅ .pb grafiği başarıyla içe aktarıldı.")
                
                # ÖNEMLİ: Dondurucu (freezer) betiğinizdeki GİRİŞ adıyla eşleşmeli
                self.input_tensor = self.graph.get_tensor_by_name("board_input:0")
                
                # Dondurucu (freezer) betiğinizdeki ÇIKIŞ adlarıyla eşleşmeli
                # (Genellikle "Identity:0" ve "Identity_1:0" olur)
                self.output_policy = self.graph.get_tensor_by_name("Identity:0") 
                self.output_value = self.graph.get_tensor_by_name("Identity_1:0")
                
                print("✅ Giriş/Çıkış node'ları başarıyla bağlandı.")
        except Exception as e: 
            print(f"❌ HATA: .pb modeli yüklenemedi veya node'lar bulunamadı: {e}")
            print("İPUCU: Dondurucu (convert_to_pb.py) betiğinin çıktısındaki node adlarını kontrol edin.")
            sys.exit()

        # Parça haritası (v1.1 ile aynı)
        self.piece_map = {'P': 0,'N': 1,'B': 2,'R': 3,'Q': 4,'K': 5, 'p': 6,'n': 7,'b': 8,'r': 9,'q': 10,'k': 11}
        self.eval_cache = {}

    def _board_to_tensor(self, board: chess.Board):
        # ==========================================================
        # "TERTEMİZ" v1.2 D18 DÜZELTMESİ BURADA
        # ==========================================================
        # KAIRA v0.8 "Büyük Beyin" (D18) eğitimine uygun tensör üret
        # Artık (1, 8, 8, 18) üretiyoruz
        
        tensor = np.zeros((1, 8, 8, 18), dtype=np.float32)

        # 1. DÜZLEMLER 0-11: Parçalar (v1.1 ile aynı)
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece:
                rank = chess.square_rank(square)
                file = chess.square_file(square)
                tensor_r = 7 - rank  # Tahtayı NN için çevir (rank 8 -> 0)
                tensor_c = file
                tensor[0, tensor_r, tensor_c, self.piece_map[piece.symbol()]] = 1

        # --- YENİ EKLENEN 6 KURAL KATMANI (D18) ---

        # 2. DÜZLEM 12: Sıra Kimde (1.0 = Beyaz, 0.0 = Siyah)
        turn_value = 1.0 if board.turn == chess.WHITE else 0.0
        tensor[0, :, :, 12] = turn_value # Tüm düzlemi doldur

        # 3. DÜZLEM 13: Beyaz Kısa Rok (K)
        if board.has_kingside_castling_rights(chess.WHITE):
            tensor[0, :, :, 13] = 1.0

        # 4. DÜZLEM 14: Beyaz Uzun Rok (Q)
        if board.has_queenside_castling_rights(chess.WHITE):
            tensor[0, :, :, 14] = 1.0

        # 5. DÜZLEM 15: Siyah Kısa Rok (k)
        if board.has_kingside_castling_rights(chess.BLACK):
            tensor[0, :, :, 15] = 1.0

        # 6. DÜZLEM 16: Siyah Uzun Rok (q)
        if board.has_queenside_castling_rights(chess.BLACK):
            tensor[0, :, :, 16] = 1.0

        # 7. DÜZLEM 17: En Passant Karesi
        ep_square = board.ep_square
        if ep_square is not None:
            rank = chess.square_rank(ep_square)
            file = chess.square_file(ep_square)
            tensor_r = 7 - rank # Rank'ı çevir
            tensor_c = file
            tensor[0, tensor_r, tensor_c, 17] = 1.0

        # ==========================================================
        # DÜZELTME BİTTİ
        # ==========================================================

        return tensor

    def _flip_policy(self, policy):
        # ... (Bu fonksiyon v1.1 ile aynı, değişiklik yok) ...
        EXPECTED_LEN = len(INDEX_TO_MOVE_UCI)
        policy_arr = np.asarray(policy).ravel()
        if policy_arr.size != EXPECTED_LEN: print(f"[WARN] _flip_policy: policy uzunluğu beklenenden farklı: {policy_arr.size} (beklenen {EXPECTED_LEN})")
        flipped_policy = np.zeros(EXPECTED_LEN, dtype=policy_arr.dtype)
        if FLIPPED_INDEX_MAP:
            for idx, prob in enumerate(policy_arr):
                if prob == 0: continue
                flipped_idx = FLIPPED_INDEX_MAP.get(idx)
                if flipped_idx is not None and flipped_idx < EXPECTED_LEN: flipped_policy[flipped_idx] = prob
            return flipped_policy
        for index, prob in enumerate(policy_arr):
            if prob == 0: continue;
            if index >= EXPECTED_LEN: continue
            try: uci_move = INDEX_TO_MOVE_UCI[index]
            except Exception: continue
            try: move = chess.Move.from_uci(uci_move)
            except Exception: continue
            try: flipped_from = chess.square_mirror(move.from_square); flipped_to = chess.square_mirror(move.to_square)
            except Exception: continue
            try:
                from_sq = chess.square_name(flipped_from); to_sq = chess.square_name(flipped_to); prom = '';
                if move.promotion is not None:
                    if move.promotion in (chess.QUEEN, chess.ROOK, chess.BISHOP, chess.KNIGHT): prom = chess.piece_symbol(move.promotion).lower()
                    else: continue
                flipped_uci = from_sq + to_sq + prom
            except Exception: continue
            flipped_index = MOVE_UCI_TO_INDEX.get(flipped_uci)
            if flipped_index is None: continue
            if flipped_index >= EXPECTED_LEN: continue
            flipped_policy[flipped_index] = prob
        return flipped_policy

    def evaluate(self, board: chess.Board):
        """
        GÜNCELLENMİŞ v1.1 - "Tertemiz" WDL Düzeltmesi (Korundu)
        """
        # --- (Önbellekleme ve Çevirme - v1.1 ile aynı) ---
        is_black_turn = (board.turn == chess.BLACK)
        eval_board = board.mirror() if is_black_turn else board
        
        # D18 UYUMLU ÖNBELLEKLEME (Cache Key Fix)
        # Önceki versiyonda sadece taş yerleşimi (fen.split(' ')[0]) alınıyordu.
        # D18 modeli rok ve en-passant haklarına duyarlı olduğu için artık tam FEN kullanılmalı.
        fen_key = eval_board.fen() 
        
        if fen_key in self.eval_cache:
            cached_policy, cached_value = self.eval_cache[fen_key]
            if is_black_turn: return self._flip_policy(cached_policy), -cached_value
            else: return cached_policy, cached_value
        # --- (Önbellekleme Bitti) ---

        # 1. Tensörü Hazırla (Artık D18 üretiyor)
        tensor = self._board_to_tensor(eval_board)

        # 2. MODELİ ÇALIŞTIR (v1.1 ile aynı)
        try:
            policy_output, value_logits = self.session.run(
                [self.output_policy, self.output_value],
                feed_dict={
                    self.input_tensor: tensor # <--- Artık (1, 8, 8, 18) besliyor
                }
            )
            model_policy = policy_output.squeeze()
            model_value_head = value_logits.squeeze()
            
        except Exception as e:
            print(f"❌ HATA: Model çalıştırılırken (session.run) hata alındı: {e}")
            return np.zeros(len(INDEX_TO_MOVE_UCI)), 0.0

        # 3. DEĞER HESABI (v1.1 WDL Düzeltmesi - KORUNDU)
        try:
            value_probs = _softmax(model_value_head) # [W, D, L]
        except Exception:
            value_probs = model_value_head / (np.sum(model_value_head) + 1e-12)
        
        if value_probs.size >= 3:
            # "Tertemiz" MCTS Değeri = Kazanma Olasılığı - Kaybetme Olasılığı
            model_value = float(value_probs[0] - value_probs[2]) # (W - L)
        else:
            model_value = float(value_probs.ravel()[0]) # Fallback

        model_value = max(-1.0, min(1.0, model_value))
        
        # 4. Önbelleğe Al ve Döndür (v1.1 ile aynı)
        self.eval_cache[fen_key] = (model_policy, model_value)

        if is_black_turn:
            final_policy = self._flip_policy(model_policy)
            final_value = -model_value
            return final_policy, final_value
        else:
            return model_policy, model_value

# =======================================================================
# --- 3. HESAPLAMA MOTORU (MCTS) ---
# =======================================================================
# (MCTS_Node ve KAIRA_Engine sınıfları v1.1 ile BİREBİR AYNIDIR)
# (Değişiklik yapmaya gerek yok, onlar sadece KAIRA_Network'ü kullanır)
# =======================================================================
class MCTS_Node:
    def __init__(self, parent=None, move=None, prior_prob=0):
        self.parent = parent; self.move = move; self.children = []
        self.n_visits = 0; self.total_value = 0.0; self.prior_prob = prior_prob
    @property
    def q_value(self):
        if self.n_visits == 0: return 0.0
        return self.total_value / self.n_visits
    def is_leaf(self): return len(self.children) == 0
    def select_child(self, c_puct=1.5):
        best_score = -float('inf'); best_child = None
        parent_sqrt_visits = math.sqrt(self.n_visits)
        for child in self.children:
            q_value_for_parent = -child.q_value
            u_value = c_puct * child.prior_prob * (parent_sqrt_visits / (1 + child.n_visits))
            score = q_value_for_parent + u_value
            if score > best_score:
                best_score = score; best_child = child
        return best_child
    def expand_node(self, board, policy_probs):
        policy_arr = np.asarray(policy_probs).ravel(); EXPECTED_LEN = len(INDEX_TO_MOVE_UCI)
        if policy_arr.size != EXPECTED_LEN: print(f"[WARN] expand_node: policy_probs uzunluğu beklenenden farklı: {policy_arr.size} (beklenen {EXPECTED_LEN}). Eksik indeksler 0.0 olarak alınacak.")
        for move in board.legal_moves:
            move_uci = move.uci(); policy_index = MOVE_UCI_TO_INDEX.get(move_uci); prob = 0.0
            if policy_index is not None:
                if policy_index < policy_arr.size: prob = float(policy_arr[policy_index])
                else: print(f"[WARN] expand_node: policy_index {policy_index} >= policy_probs.size ({policy_arr.size}) for move {move_uci}")
            child_node = MCTS_Node(parent=self, move=move, prior_prob=prob)
            self.children.append(child_node)
    def backpropagate(self, value):
        node = self
        while node is not None:
            node.n_visits += 1
            node.total_value += value if node.parent is None else -value
            value = -value; node = node.parent

class KAIRA_Engine:
    def __init__(self, model_path):
        self.network = KAIRA_Network(model_path) # <-- Artık v1.2 (D18) kullanıyor
        self.expected_policy_len = len(MOVE_UCI_TO_INDEX)
        print(f"✅ KAIRA MCTS Motoru v1.2 (D18) {self.expected_policy_len} hamlelik harita ile hazır.")

    def get_nn_top_moves(self, board, k=5):
        policy, _ = self.network.evaluate(board)
        policy_arr = np.asarray(policy).ravel()
        if policy_arr.size != self.expected_policy_len: print(f"[WARN] get_nn_top_moves: policy uzunluğu beklenenden farklı: {policy_arr.size} (beklenen {self.expected_policy_len})")
        moves_with_probs = []
        for move in board.legal_moves:
            policy_index = MOVE_UCI_TO_INDEX.get(move.uci())
            prob = 0.0
            if policy_index is not None and policy_index < policy_arr.size: prob = float(policy_arr[policy_index])
            moves_with_probs.append((move, prob))
        moves_with_probs.sort(key=lambda x: x[1], reverse=True)
        return moves_with_probs[:k]

    def choose_move(self, board, mcts_sims=400, deterministic=True,
                    add_noise=False, dirichlet_alpha=0.3, noise_fraction=0.25):
        """
        Dirichlet gürültüsü ekleme seçeneği içerir.
        (best_move, policy_target) döndürür.
        """
        
        self.network.eval_cache.clear()
        
        policy_target = np.zeros(self.expected_policy_len, dtype=np.float32)
        
        if board.is_game_over():
            print("Oyun bitti."); return None, policy_target

        num_simulations = mcts_sims
        start_time = time.time()
        
        root_node = MCTS_Node()
        policy, value = self.network.evaluate(board)
        root_node.expand_node(board, policy)
        root_node.value_pred = value

        if add_noise and root_node.children:
            try:
                noise = np.random.dirichlet([dirichlet_alpha] * len(root_node.children))
                for i, child in enumerate(root_node.children):
                    child.prior_prob = (1 - noise_fraction) * child.prior_prob + \
                                       noise_fraction * noise[i]
            except Exception as e:
                print(f"❌ HATA: Dirichlet gürültüsü uygulanırken hata: {e}")

        if num_simulations == 0:
            if not root_node.children: return None, policy_target
            best_move = max(root_node.children, key=lambda c: c.prior_prob).move
            print("MCTS 0 simülasyon: Sadece sezgisel hamle oynandı.")
            return best_move, policy_target
            
        for _ in range(num_simulations):
            node = root_node; sim_board = board.copy()
            while not node.is_leaf():
                node = node.select_child(); sim_board.push(node.move)
            
            if sim_board.is_game_over(claim_draw=True): 
                value = 0.0 
                if sim_board.is_checkmate():
                    value = -1.0
            else:
                policy, value = self.network.evaluate(sim_board)
                node.expand_node(sim_board, policy)
            node.backpropagate(value)

        if not root_node.children:
            print("Hata: Kök düğümün hiç çocuğu yok."); return None, policy_target

        total_visits = sum(c.n_visits for c in root_node.children)
        if total_visits > 0:
            for child in root_node.children:
                move_uci = child.move.uci()
                policy_index = MOVE_UCI_TO_INDEX.get(move_uci)
                if policy_index is not None:
                    policy_target[policy_index] = child.n_visits / total_visits

        if deterministic:
            best_child = max(root_node.children, key=lambda c: c.n_visits)
        else:
            visit_counts = np.array([c.n_visits for c in root_node.children])
            visit_probs = visit_counts / (np.sum(visit_counts) + 1e-6)
            if np.sum(visit_probs) == 0:
                best_child = np.random.choice(root_node.children)
            else:
                best_child = np.random.choice(root_node.children, p=visit_probs)
            
        best_move = best_child.move
        
        end_time = time.time()
        nps = num_simulations / (end_time - start_time + 0.001)
        
        if (nps < 5000):
             print(f"MCTS tamamlandı ({nps:.1f} sim/saniye).")
             print(f"  > Seçilen Hamle: {best_move.uci()} (Ziyaret: {best_child.n_visits})")
        
        return best_move, policy_target

if __name__ == '__main__':
    print('KAIRA motoru modülü yüklendi (v1.2 - D18 "BÜYÜK BEYİN" UYUMLU).')
