import re
import binascii
from collections import defaultdict, Counter

# Read a binary file and return its content as a hexadecimal string
def read_binary_file(file_path):
    """Return the content of *file_path* as a lowercase hex string.

    The file is opened in binary mode and read in full; every byte becomes
    two hexadecimal characters in the returned string.
    """
    with open(file_path, 'rb') as handle:
        raw = handle.read()
    return raw.hex()

# Read the log file and return its content line by line
def read_log_file(file_path):
    """Return the lines of the text file *file_path* as a list.

    Line terminators are kept, exactly as the file iterator yields them.
    """
    with open(file_path, 'r') as handle:
        return list(handle)

# Split the log lines into segments, one per "Index:" header
def extract_segments(log_content):
    """Group consecutive log lines into segments.

    A new segment starts at every line beginning with "Index:". Lines that
    appear before the first such header are gathered into a leading segment
    of their own.
    """
    groups = []
    for entry in log_content:
        opens_group = entry.startswith("Index:")
        if opens_group or not groups:
            # Either an explicit header, or the very first line of the log.
            groups.append([entry])
        else:
            groups[-1].append(entry)
    return groups

# Extract the inverted bytes listed in a log segment
def extract_inverted_bytes(segment):
    """Return the integers that follow "Inverted Byte: " in *segment*.

    Each line contributes at most one value (its first occurrence); lines
    without the marker are ignored. Values are parsed as decimal integers.
    """
    hits = (re.search(r'Inverted Byte: (\d+)', text) for text in segment)
    return [int(hit.group(1)) for hit in hits if hit]

# Extract the inverted bytes of a log segment together with their positions
def extract_inverted_bytes_with_positions(segment, index, mcuid_length):
    """Return ``(inverted_bytes, positions)`` for one log segment.

    The k-th value found after "Inverted Byte: " is assigned the position
    ``index * mcuid_length + k``.
    """
    found = (re.search(r'Inverted Byte: (\d+)', text) for text in segment)
    values = [int(hit.group(1)) for hit in found if hit]
    base = index * mcuid_length
    offsets = [base + k for k in range(len(values))]
    return values, offsets

# Split the hex content of the binary file into fixed-size segments
def extract_binary_segments(file_content, segment_size):
    """Return ``(index, hex_chunk)`` pairs covering *file_content*.

    *file_content* is a hex string, so a segment of *segment_size* bytes
    spans ``segment_size * 2`` characters. Only complete segments are
    returned; a trailing partial segment is dropped.
    """
    width = segment_size * 2  # two hex characters per byte
    total = len(file_content) // width
    return [(n, file_content[n * width:(n + 1) * width]) for n in range(total)]

# Apply the candidate transformations to a binary segment
def apply_specific_transformations(segment, mcuid_bytes):
    """Combine each byte of *segment* with the matching MCUID byte.

    *segment* is a hex string; byte ``i`` is read from characters
    ``2*i`` and ``2*i + 1``. For every position a dict holding the results
    of modular addition, XOR and modular subtraction is recorded.

    Returns ``(True, records)`` as soon as one of the three results at a
    position equals the MCUID byte itself (``records`` then stops at that
    position), otherwise ``(False, records)`` for all positions.
    """
    records = []
    for pos, key in enumerate(mcuid_bytes):
        source = int(segment[pos * 2:pos * 2 + 2], 16)
        outcome = {
            "addition_mod": (source + key) % 256,
            "xor": source ^ key,
            "subtraction_mod": (source - key) % 256,
        }
        records.append(outcome)
        if key in outcome.values():
            return True, records
    return False, records

# Recover the MCUID bytes by inverting the transformation recorded at each position
def invert_transformations(segment, transformations, mcuid_bytes):
    """Invert the recorded transformations to recover the MCUID bytes.

    Args:
        segment: Hex string; byte ``i`` is read from characters ``2*i`` and
            ``2*i + 1``.
        transformations: One dict per position with the keys
            "addition_mod", "xor" and "subtraction_mod".
        mcuid_bytes: Indexable sequence of expected MCUID byte values, at
            least as long as *transformations*.

    Returns:
        A list with one entry per transformation: the byte recovered by
        inverting the first transformation (addition, then XOR, then
        subtraction) whose recorded result is consistent with
        ``mcuid_bytes[i]``, or ``None`` when none of the three is.
    """
    inverted_bytes = []
    for i, trans in enumerate(transformations):
        original_byte = int(segment[i * 2:(i + 1) * 2], 16)
        expected = mcuid_bytes[i]
        if trans["addition_mod"] == (original_byte + expected) % 256:
            # t = (o + m) % 256  =>  m = (t - o) % 256
            mcuid_byte = (trans["addition_mod"] - original_byte) % 256
        elif trans["xor"] == (original_byte ^ expected):
            # t = o ^ m  =>  m = o ^ t
            mcuid_byte = original_byte ^ trans["xor"]
        elif trans["subtraction_mod"] == (original_byte - expected) % 256:
            # t = (o - m) % 256  =>  m = (o - t) % 256
            # (adding t to o, as done previously, yields 2*o - m instead).
            mcuid_byte = (original_byte - trans["subtraction_mod"]) % 256
        else:
            mcuid_byte = None
        inverted_bytes.append(mcuid_byte)
    return inverted_bytes

# Main entry point: recover the MCUID from a binary file with the help of the log file
def find_mcuid(binary_file_path, log_file_path):
    """Recover a candidate MCUID from a binary file using a segment log.

    The binary file is split into 16-byte segments. Each log segment whose
    "Index: N" header equals a binary segment's index supplies the
    "Inverted Byte" values tested against that segment. Bytes recovered
    from matching segments are collected per absolute position
    (``index * 16 + offset``) and the most frequent value is kept for each
    position.

    Args:
        binary_file_path: Path of the binary file to analyze.
        log_file_path: Path of the log file describing the segments.

    Returns:
        Uppercase hex string built from the bytes of the first 16 populated
        positions (shorter, possibly empty, when fewer were recovered).

    Raises:
        ValueError: If a log segment lists an "Inverted Byte" value outside
            0-255.
    """
    mcuid_length = 16  # MCUID length in bytes

    log_content = read_log_file(log_file_path)
    log_segments = extract_segments(log_content)

    # Debug print statements
    print(f"Number of log segments: {len(log_segments)}")

    # Group the log segments by their "Index: N" header once, instead of
    # re-parsing every header for every binary segment.
    log_segments_by_index = defaultdict(list)
    for log_segment in log_segments:
        header = re.search(r'Index: (\d+)', log_segment[0])
        if header is None:
            # Segment without a usable header (e.g. lines preceding the
            # first "Index:" line): it cannot be paired with binary data.
            continue
        log_segments_by_index[int(header.group(1))].append(log_segment)

    # Read the binary file
    file_content = read_binary_file(binary_file_path)
    binary_segments = extract_binary_segments(file_content, mcuid_length)

    # Debug print statements
    print(f"Number of binary segments: {len(binary_segments)}")

    combined_bytes = defaultdict(list)
    for index, segment in binary_segments:
        print(f"Processing binary segment index: {index}")  # Debug print statement
        # .get() keeps the defaultdict from growing on unmatched indices.
        for log_segment in log_segments_by_index.get(index, ()):
            # bytes() rejects values outside 0-255 instead of silently
            # misaligning them as a hex-string round trip would.
            mcuid_bytes = bytes(extract_inverted_bytes(log_segment))
            match, transformations = apply_specific_transformations(segment, mcuid_bytes)
            if not match:
                continue
            inverted_bytes = invert_transformations(segment, transformations, mcuid_bytes)
            for i, byte in enumerate(inverted_bytes):
                if byte is None:
                    # No transformation could be inverted at this offset.
                    continue
                # Use the fixed segment size as stride: the number of
                # recovered bytes varies per segment, so using it as the
                # stride would make positions of different segments collide.
                position = index * mcuid_length + i
                combined_bytes[position].append(byte)

    # Debug print statements
    print("Combined Bytes:", combined_bytes)

    # Sort the combined bytes by position and keep the most frequent value of each
    sorted_indices = sorted(combined_bytes)
    sorted_bytes = [Counter(combined_bytes[i]).most_common(1)[0][0] for i in sorted_indices]

    # Debug print statements
    print("Sorted Bytes:", sorted_bytes)

    # Build the candidate MCUID from the first mcuid_length recovered bytes
    possible_mcuid_hex = ''.join(f'{byte:02X}' for byte in sorted_bytes[:mcuid_length])
    return possible_mcuid_hex

if __name__ == "__main__":
    import argparse

    # Command-line interface: two positional paths, binary file first.
    cli = argparse.ArgumentParser(description="Find MCUID in a binary file using matching_segments.log")
    cli.add_argument("binary_file_path", type=str, help="Path to the binary file to analyze")
    cli.add_argument("log_file_path", type=str, help="Path to the log file to analyze")
    options = cli.parse_args()

    # Run the extraction and report the result
    mcuid = find_mcuid(options.binary_file_path, options.log_file_path)
    print(f"Extracted MCUID: {mcuid}")
