const { useState, useMemo, useCallback } = React; // Project 3D point to 2D screen const project = (point, rotX, rotY, scale = 100, offset = [200, 200]) => { let [x, y, z] = point; const cosY = Math.cos(rotY), sinY = Math.sin(rotY); [x, z] = [x * cosY - z * sinY, x * sinY + z * cosY]; const cosX = Math.cos(rotX), sinX = Math.sin(rotX); [y, z] = [y * cosX - z * sinX, y * sinX + z * cosX]; return [offset[0] + x * scale, offset[1] - y * scale, z]; }; const relu = (x) => Math.max(0, x); const neuronActivation = (point, w, b) => { const preActivation = w[0] * point[0] + w[1] * point[1] + b; return { pre: preActivation, post: relu(preActivation) }; }; // Dataset generators const datasets = { xor: { name: 'XOR Pattern', description: 'Classic XOR: opposite corners share class', generate: () => { const points = []; const class0Centers = [[-0.7, 0.7], [0.7, -0.7]]; const class1Centers = [[0.7, 0.7], [-0.7, -0.7]]; class0Centers.forEach(center => { for (let i = 0; i < 5; i++) { const angle = (i / 5) * Math.PI * 2 + Math.random() * 0.5; const r = 0.1 + Math.random() * 0.12; points.push({ pos: [center[0] + Math.cos(angle) * r, center[1] + Math.sin(angle) * r], class: 0 }); } }); class1Centers.forEach(center => { for (let i = 0; i < 5; i++) { const angle = (i / 5) * Math.PI * 2 + Math.random() * 0.5; const r = 0.1 + Math.random() * 0.12; points.push({ pos: [center[0] + Math.cos(angle) * r, center[1] + Math.sin(angle) * r], class: 1 }); } }); return points; }, recommendedNeurons: [ { w: [1, 1], b: 0 }, { w: [1, -1], b: 0 }, { w: [0, 1], b: 0 } ] }, circles: { name: 'Concentric Circles', description: 'Inner circle vs outer ring', generate: () => { const points = []; // Inner circle - class 0 for (let i = 0; i < 12; i++) { const angle = (i / 12) * Math.PI * 2 + Math.random() * 0.3; const r = 0.2 + Math.random() * 0.15; points.push({ pos: [Math.cos(angle) * r, Math.sin(angle) * r], class: 0 }); } // Outer ring - class 1 for (let i = 0; i < 14; i++) { const angle = (i / 14) * Math.PI * 2 + Math.random() * 0.2; const r = 0.7 + Math.random() * 0.2; points.push({ pos: [Math.cos(angle) * r, Math.sin(angle) * r], class: 1 }); } return points; }, recommendedNeurons: [ { w: [1, 0], b: 0.4 }, { w: [-1, 0], b: 0.4 }, { w: [0, 1], b: 0.4 } ] }, spiral: { name: 'Two Spirals', description: 'Interleaved spiral arms', generate: () => { const points = []; const turns = 1.5; for (let i = 0; i < 15; i++) { const t = i / 15; const angle = t * turns * Math.PI * 2; const r = 0.2 + t * 0.6; const noise = () => (Math.random() - 0.5) * 0.1; // Spiral 1 points.push({ pos: [Math.cos(angle) * r + noise(), Math.sin(angle) * r + noise()], class: 0 }); // Spiral 2 (offset by PI) points.push({ pos: [Math.cos(angle + Math.PI) * r + noise(), Math.sin(angle + Math.PI) * r + noise()], class: 1 }); } return points; }, recommendedNeurons: [ { w: [1, 0.5], b: 0.2 }, { w: [-0.5, 1], b: 0.2 }, { w: [1, -1], b: 0 } ] }, moons: { name: 'Two Moons', description: 'Interlocking crescent shapes', generate: () => { const points = []; // Upper moon - class 0 for (let i = 0; i < 12; i++) { const angle = (i / 12) * Math.PI; const r = 0.5 + (Math.random() - 0.5) * 0.15; points.push({ pos: [Math.cos(angle) * r, Math.sin(angle) * r], class: 0 }); } // Lower moon - class 1 (shifted and flipped) for (let i = 0; i < 12; i++) { const angle = Math.PI + (i / 12) * Math.PI; const r = 0.5 + (Math.random() - 0.5) * 0.15; points.push({ pos: [Math.cos(angle) * r + 0.5, Math.sin(angle) * r + 0.3], class: 1 }); } return points; }, recommendedNeurons: [ { w: [1, 1], b: 0.3 }, { w: [1, -1], b: 0.3 }, { w: [-1, 0], b: 0.3 } ] }, checkerboard: { name: 'Checkerboard', description: '2x2 alternating grid pattern', generate: () => { const points = []; const centers = [ { pos: [-0.5, 0.5], class: 0 }, { pos: [0.5, 0.5], class: 1 }, { pos: [-0.5, -0.5], class: 1 }, { pos: [0.5, -0.5], class: 0 } ]; centers.forEach(({ pos, class: cls }) => { for (let i = 0; i < 5; i++) { const noise = () => (Math.random() - 0.5) * 0.3; points.push({ pos: [pos[0] + noise(), pos[1] + noise()], class: cls }); } }); return points; }, recommendedNeurons: [ { w: [1, 0], b: 0 }, { w: [0, 1], b: 0 }, { w: [1, 1], b: 0 } ] }, stripe: { name: 'Middle Stripe', description: 'Class 1 forms a diagonal band', generate: () => { const points = []; for (let i = 0; i < 30; i++) { const x = (Math.random() - 0.5) * 2; const y = (Math.random() - 0.5) * 2; const inStripe = Math.abs(x + y) < 0.5; points.push({ pos: [x, y], class: inStripe ? 1 : 0 }); } return points; }, recommendedNeurons: [ { w: [1, 1], b: 0.5 }, { w: [-1, -1], b: 0.5 }, { w: [1, -1], b: 0 } ] } }; function NNGeometryVisualizer() { const [foldRotX, setFoldRotX] = useState(0.4); const [foldRotY, setFoldRotY] = useState(-0.3); const [hiddenRotX, setHiddenRotX] = useState(0.5); const [hiddenRotY, setHiddenRotY] = useState(0.4); const [dragging, setDragging] = useState(null); const [lastMouse, setLastMouse] = useState([0, 0]); const [liftAmount, setLiftAmount] = useState(0); const [showGrid, setShowGrid] = useState(true); const [showDecisionPlane, setShowDecisionPlane] = useState(true); const [activeNeuron, setActiveNeuron] = useState(null); const [highlightedPoint, setHighlightedPoint] = useState(null); const [selectedDataset, setSelectedDataset] = useState('xor'); const [neuronCount, setNeuronCount] = useState(3); const [neurons, setNeurons] = useState(datasets.xor.recommendedNeurons); const [dataKey, setDataKey] = useState(0); // For regenerating data // Output weights adapt to neuron count const outputWeights = useMemo(() => { if (neuronCount === 1) return [1]; if (neuronCount === 2) return [1, -1]; return [0.8, 0.8, 1.2]; }, [neuronCount]); const outputBias = -0.5; // Generate data points const dataPoints = useMemo(() => { return datasets[selectedDataset].generate(); }, [selectedDataset, dataKey]); // Active neurons (limited by count) const activeNeurons = useMemo(() => neurons.slice(0, neuronCount), [neurons, neuronCount]); // Compute data in all spaces const processedData = useMemo(() => { return dataPoints.map((d, idx) => { const activations = activeNeurons.map(n => neuronActivation(d.pos, n.w, n.b)); const h = activations.map(a => a.post); // Pad to 3D for visualization while (h.length < 3) h.push(0); const foldedZ = (h[0] * 0.4 + h[1] * 0.4 + h[2] * 0.3) * liftAmount; const folded = [d.pos[0], d.pos[1], foldedZ]; const hidden = [h[0], h[1], h[2]]; const output = outputWeights.reduce((sum, w, i) => sum + w * (activations[i]?.post || 0), 0) + outputBias; return { ...d, idx, activations, h, folded, hidden, output }; }); }, [dataPoints, activeNeurons, liftAmount, outputWeights]); // Grid for folded view const gridPoints = useMemo(() => { const points = []; const res = 20, range = 1.4; for (let i = 0; i <= res; i++) { for (let j = 0; j <= res; j++) { const x = (i / res - 0.5) * 2 * range; const y = (j / res - 0.5) * 2 * range; const h = activeNeurons.map(n => relu(n.w[0] * x + n.w[1] * y + n.b)); while (h.length < 3) h.push(0); const z = (h[0] * 0.4 + h[1] * 0.4 + h[2] * 0.3) * liftAmount; points.push({ pos: [x, y], lifted: [x, y, z], h }); } } return points; }, [activeNeurons, liftAmount]); // Neuron lines const neuronLines = useMemo(() => { return activeNeurons.map(({ w, b }) => { const range = 2; if (Math.abs(w[1]) > Math.abs(w[0])) { return [[-range, (-w[0] * -range - b) / w[1]], [range, (-w[0] * range - b) / w[1]]]; } else if (Math.abs(w[0]) > 0.01) { return [[(-w[1] * -range - b) / w[0], -range], [(-w[1] * range - b) / w[0], range]]; } return [[-range, -b / (w[1] || 0.01)], [range, -b / (w[1] || 0.01)]]; }); }, [activeNeurons]); // Decision boundary in 2D const decisionBoundary2D = useMemo(() => { const points = []; const res = 35, range = 1.4; for (let i = 0; i < res; i++) { for (let j = 0; j < res; j++) { const corners = [ [(i/res - 0.5) * 2 * range, (j/res - 0.5) * 2 * range], [((i+1)/res - 0.5) * 2 * range, (j/res - 0.5) * 2 * range], [((i+1)/res - 0.5) * 2 * range, ((j+1)/res - 0.5) * 2 * range], [(i/res - 0.5) * 2 * range, ((j+1)/res - 0.5) * 2 * range] ]; const outputs = corners.map(p => { const h = activeNeurons.map(n => relu(n.w[0] * p[0] + n.w[1] * p[1] + n.b)); return outputWeights.reduce((sum, w, i) => sum + w * (h[i] || 0), 0) + outputBias; }); if (outputs.some(o => o > 0) && outputs.some(o => o < 0)) { const cx = (corners[0][0] + corners[2][0]) / 2; const cy = (corners[0][1] + corners[2][1]) / 2; const h = activeNeurons.map(n => relu(n.w[0] * cx + n.w[1] * cy + n.b)); while (h.length < 3) h.push(0); const z = (h[0] * 0.4 + h[1] * 0.4 + h[2] * 0.3) * liftAmount; points.push([cx, cy, z]); } } } return points; }, [activeNeurons, liftAmount, outputWeights]); const handleMouseDown = useCallback((view) => (e) => { setDragging(view); setLastMouse([e.clientX, e.clientY]); }, []); const handleMouseMove = useCallback((e) => { if (!dragging) return; const dx = e.clientX - lastMouse[0]; const dy = e.clientY - lastMouse[1]; if (dragging === 'fold') { setFoldRotY(prev => prev + dx * 0.01); setFoldRotX(prev => Math.max(-Math.PI/2, Math.min(Math.PI/2, prev + dy * 0.01))); } else if (dragging === 'hidden') { setHiddenRotY(prev => prev + dx * 0.01); setHiddenRotX(prev => Math.max(-Math.PI/2, Math.min(Math.PI/2, prev + dy * 0.01))); } setLastMouse([e.clientX, e.clientY]); }, [dragging, lastMouse]); const handleMouseUp = useCallback(() => setDragging(null), []); const updateNeuron = (idx, field, subfield, value) => { setNeurons(prev => { const next = [...prev]; if (subfield !== null) { next[idx] = { ...next[idx], [field]: [...next[idx][field]] }; next[idx][field][subfield] = value; } else { next[idx] = { ...next[idx], [field]: value }; } return next; }); }; const handleDatasetChange = (key) => { setSelectedDataset(key); setNeurons(datasets[key].recommendedNeurons); setNeuronCount(3); setDataKey(prev => prev + 1); setHighlightedPoint(null); }; const handleNeuronCountChange = (count) => { setNeuronCount(count); setActiveNeuron(null); }; const neuronColors = ['#ff6b6b', '#4ecdc4', '#ffd93d']; const gridRes = 21; const separationInfo = useMemo(() => { const correct = processedData.filter(d => (d.output > 0) === (d.class === 1)).length; return { accuracy: (correct / processedData.length * 100).toFixed(0), correct, total: processedData.length }; }, [processedData]); return (
Explore how neurons fold 2D input space to make non-linear patterns linearly separable