import React, {useEffect} from "react";
import ReactFlow, {MarkerType, useEdgesState, useNodesState} from "reactflow";
import 'reactflow/dist/style.css';
import PropTypes from "prop-types";
import {useTheme} from "@mui/system";

function ArchitectureRenderer({layers, height = 400}){

    const theme = useTheme();

    /**
     * Returns a sorted copy of the layers: `input` layers first, `output`
     * layers last, every other type in between.
     *
     * The caller's array is never modified; `Array.prototype.sort` works in
     * place, so the sort runs on a shallow copy. Layers of equal rank keep
     * their relative order because the built-in sort is stable.
     *
     * @param {Array<{type: string}>} layerArray - Layer descriptors to order.
     * @returns {Array<{type: string}>} A new array in display order.
     */
    const sortLayers = (layerArray) => {
        // A numeric rank keeps the comparator symmetric (cmp(a, b) === -cmp(b, a)),
        // which a chain of one-sided `if` checks does not guarantee.
        const rankOf = (layer) => {
            if (layer.type === 'input') return 0;
            if (layer.type === 'output') return 2;
            return 1;
        };

        return [...layerArray].sort((a, b) => rankOf(a) - rankOf(b));
    };

    /**
     * Builds the React Flow node list: one small circular node per unit of
     * every layer. Layers become columns laid out left to right in sorted
     * order, and each column is centred vertically on the same midline.
     *
     * Node ids are sequential across all layers (`node-0`, `node-1`, ...),
     * which is the numbering the edge generator relies on.
     *
     * @returns {Array<Object>} Node objects for the `nodes` prop of ReactFlow.
     */
    const generateNodes = () => {
        const columnSpacing = 150;
        const rowSpacing = 30;
        const canvasHeight = 400;
        // Sort a copy so the `layers` prop is never reordered in place.
        const orderedLayers = sortLayers([...layers]);
        let nextId = 0;

        return orderedLayers.flatMap((layer, layerIndex) => {
            // Parse the size exactly like the edge generator does, so node
            // counts and edge endpoints agree even for string sizes.
            // Unparseable or negative sizes yield an empty column.
            const parsedSize = Number.parseInt(layer.size, 10);
            const unitCount = Number.isNaN(parsedSize) ? 0 : Math.max(0, parsedSize);

            const columnHeight = unitCount * rowSpacing;
            const topY = (canvasHeight - columnHeight) / 2;
            const x = layerIndex * columnSpacing + 100;

            const isHidden = layer.type === 'hidden';
            const sourcePosition = isHidden || layer.type === 'input' ? 'right' : 'left';
            const targetPosition = isHidden || layer.type === 'output' ? 'left' : 'right';

            const column = [];
            for (let row = 0; row < unitCount; row++) {
                column.push({
                    id: `node-${nextId}`,
                    position: { x, y: topY + (row * rowSpacing) },
                    data: { label: '' },
                    type: 'default',
                    draggable: false,
                    connectable: true,  // Keep this true for custom handles
                    style: {
                        width: 20,
                        height: 20,
                        borderRadius: '50%',
                        backgroundColor: theme.palette.primary.main,
                    },
                    sourcePosition,
                    targetPosition,
                });
                nextId++;
            }
            return column;
        });
    };

    /**
     * Builds the React Flow edge list: every node of a layer is connected to
     * every node of the following layer (in sorted order).
     *
     * Endpoints refer to the sequential `node-<n>` ids, where nodes are
     * numbered layer by layer; `layerOffset` is the id of the first node of
     * the source layer.
     *
     * @returns {Array<Object>} Edge objects for the `edges` prop of ReactFlow.
     */
    const generateEdges = () => {
        // Unparseable or negative sizes count as an empty layer so they cannot
        // turn the running offset into NaN.
        const toUnitCount = (size) => {
            const parsed = Number.parseInt(size, 10);
            return Number.isNaN(parsed) ? 0 : Math.max(0, parsed);
        };

        // Sort a copy so the `layers` prop is never reordered in place.
        const orderedLayers = sortLayers([...layers]);
        const stroke = theme.palette.primary.main;
        const edges = [];
        let layerOffset = 0;

        for (let i = 1; i < orderedLayers.length; i++) {
            const sourceCount = toUnitCount(orderedLayers[i - 1].size);
            const targetCount = toUnitCount(orderedLayers[i].size);
            // Target layer nodes are numbered right after the source layer's.
            const targetOffset = layerOffset + sourceCount;

            for (let j = 0; j < sourceCount; j++) {
                for (let k = 0; k < targetCount; k++) {
                    edges.push({
                        id: `edge-${layerOffset}-${j}-${k}`,
                        source: `node-${layerOffset + j}`,
                        target: `node-${targetOffset + k}`,
                        type: 'straight',
                        markerEnd: {},
                        style: { strokeWidth: 1, stroke },
                    });
                }
            }
            layerOffset = targetOffset;
        }
        return edges;
    };

    const [nodes, setNodes] = useNodesState(generateNodes());
    const [edges, setEdges] = useEdgesState(generateEdges());

    useEffect(() => {
        setNodes(generateNodes());
        setEdges(generateEdges());
    }, [layers]);

    return (
        <div style={{height: height}}>
            <ReactFlow
                nodes={nodes}
                edges={edges}
                fitView
                nodeTypes={{}}
            />
        </div>
    );
}

// Prop validation for ArchitectureRenderer.
ArchitectureRenderer.propTypes = {
    // Layer descriptors. `type` is compared against 'input', 'hidden' and
    // 'output'; `size` is the number of nodes drawn for the layer and may be
    // given as a number or a numeric string.
    layers: PropTypes.arrayOf(
        PropTypes.shape({
            type: PropTypes.string,
            size: PropTypes.oneOfType([PropTypes.number, PropTypes.string]),
        })
    ).isRequired,
    // Height applied to the wrapping div (defaults to 400 in the component).
    height: PropTypes.number,
};

export default ArchitectureRenderer;