const { useState, useMemo, useCallback } = React; // Project 3D point to 2D screen const project = (point, rotX, rotY, scale = 100, offset = [200, 200]) => { let [x, y, z] = point; const cosY = Math.cos(rotY), sinY = Math.sin(rotY); [x, z] = [x * cosY - z * sinY, x * sinY + z * cosY]; const cosX = Math.cos(rotX), sinX = Math.sin(rotX); [y, z] = [y * cosX - z * sinX, y * sinX + z * cosX]; return [offset[0] + x * scale, offset[1] - y * scale, z]; }; const relu = (x) => Math.max(0, x); const neuronActivation = (point, w, b) => { const preActivation = w[0] * point[0] + w[1] * point[1] + b; return { pre: preActivation, post: relu(preActivation) }; }; // Dataset generators const datasets = { xor: { name: 'XOR Pattern', description: 'Classic XOR: opposite corners share class', generate: () => { const points = []; const class0Centers = [[-0.7, 0.7], [0.7, -0.7]]; const class1Centers = [[0.7, 0.7], [-0.7, -0.7]]; class0Centers.forEach(center => { for (let i = 0; i < 5; i++) { const angle = (i / 5) * Math.PI * 2 + Math.random() * 0.5; const r = 0.1 + Math.random() * 0.12; points.push({ pos: [center[0] + Math.cos(angle) * r, center[1] + Math.sin(angle) * r], class: 0 }); } }); class1Centers.forEach(center => { for (let i = 0; i < 5; i++) { const angle = (i / 5) * Math.PI * 2 + Math.random() * 0.5; const r = 0.1 + Math.random() * 0.12; points.push({ pos: [center[0] + Math.cos(angle) * r, center[1] + Math.sin(angle) * r], class: 1 }); } }); return points; }, recommendedNeurons: [ { w: [1, 1], b: 0 }, { w: [1, -1], b: 0 }, { w: [0, 1], b: 0 } ] }, circles: { name: 'Concentric Circles', description: 'Inner circle vs outer ring', generate: () => { const points = []; // Inner circle - class 0 for (let i = 0; i < 12; i++) { const angle = (i / 12) * Math.PI * 2 + Math.random() * 0.3; const r = 0.2 + Math.random() * 0.15; points.push({ pos: [Math.cos(angle) * r, Math.sin(angle) * r], class: 0 }); } // Outer ring - class 1 for (let i = 0; i < 14; i++) { const angle = (i / 14) * Math.PI * 2 + Math.random() * 0.2; const r = 0.7 + Math.random() * 0.2; points.push({ pos: [Math.cos(angle) * r, Math.sin(angle) * r], class: 1 }); } return points; }, recommendedNeurons: [ { w: [1, 0], b: 0.4 }, { w: [-1, 0], b: 0.4 }, { w: [0, 1], b: 0.4 } ] }, spiral: { name: 'Two Spirals', description: 'Interleaved spiral arms', generate: () => { const points = []; const turns = 1.5; for (let i = 0; i < 15; i++) { const t = i / 15; const angle = t * turns * Math.PI * 2; const r = 0.2 + t * 0.6; const noise = () => (Math.random() - 0.5) * 0.1; // Spiral 1 points.push({ pos: [Math.cos(angle) * r + noise(), Math.sin(angle) * r + noise()], class: 0 }); // Spiral 2 (offset by PI) points.push({ pos: [Math.cos(angle + Math.PI) * r + noise(), Math.sin(angle + Math.PI) * r + noise()], class: 1 }); } return points; }, recommendedNeurons: [ { w: [1, 0.5], b: 0.2 }, { w: [-0.5, 1], b: 0.2 }, { w: [1, -1], b: 0 } ] }, moons: { name: 'Two Moons', description: 'Interlocking crescent shapes', generate: () => { const points = []; // Upper moon - class 0 for (let i = 0; i < 12; i++) { const angle = (i / 12) * Math.PI; const r = 0.5 + (Math.random() - 0.5) * 0.15; points.push({ pos: [Math.cos(angle) * r, Math.sin(angle) * r], class: 0 }); } // Lower moon - class 1 (shifted and flipped) for (let i = 0; i < 12; i++) { const angle = Math.PI + (i / 12) * Math.PI; const r = 0.5 + (Math.random() - 0.5) * 0.15; points.push({ pos: [Math.cos(angle) * r + 0.5, Math.sin(angle) * r + 0.3], class: 1 }); } return points; }, recommendedNeurons: [ { w: [1, 1], b: 0.3 }, { w: [1, -1], b: 0.3 }, { w: [-1, 0], b: 0.3 } ] }, checkerboard: { name: 'Checkerboard', description: '2x2 alternating grid pattern', generate: () => { const points = []; const centers = [ { pos: [-0.5, 0.5], class: 0 }, { pos: [0.5, 0.5], class: 1 }, { pos: [-0.5, -0.5], class: 1 }, { pos: [0.5, -0.5], class: 0 } ]; centers.forEach(({ pos, class: cls }) => { for (let i = 0; i < 5; i++) { const noise = () => (Math.random() - 0.5) * 0.3; points.push({ pos: [pos[0] + noise(), pos[1] + noise()], class: cls }); } }); return points; }, recommendedNeurons: [ { w: [1, 0], b: 0 }, { w: [0, 1], b: 0 }, { w: [1, 1], b: 0 } ] }, stripe: { name: 'Middle Stripe', description: 'Class 1 forms a diagonal band', generate: () => { const points = []; for (let i = 0; i < 30; i++) { const x = (Math.random() - 0.5) * 2; const y = (Math.random() - 0.5) * 2; const inStripe = Math.abs(x + y) < 0.5; points.push({ pos: [x, y], class: inStripe ? 1 : 0 }); } return points; }, recommendedNeurons: [ { w: [1, 1], b: 0.5 }, { w: [-1, -1], b: 0.5 }, { w: [1, -1], b: 0 } ] } }; function NNGeometryVisualizer() { const [foldRotX, setFoldRotX] = useState(0.4); const [foldRotY, setFoldRotY] = useState(-0.3); const [hiddenRotX, setHiddenRotX] = useState(0.5); const [hiddenRotY, setHiddenRotY] = useState(0.4); const [dragging, setDragging] = useState(null); const [lastMouse, setLastMouse] = useState([0, 0]); const [liftAmount, setLiftAmount] = useState(0); const [showGrid, setShowGrid] = useState(true); const [showDecisionPlane, setShowDecisionPlane] = useState(true); const [activeNeuron, setActiveNeuron] = useState(null); const [highlightedPoint, setHighlightedPoint] = useState(null); const [selectedDataset, setSelectedDataset] = useState('xor'); const [neuronCount, setNeuronCount] = useState(3); const [neurons, setNeurons] = useState(datasets.xor.recommendedNeurons); const [dataKey, setDataKey] = useState(0); // For regenerating data // Output weights adapt to neuron count const outputWeights = useMemo(() => { if (neuronCount === 1) return [1]; if (neuronCount === 2) return [1, -1]; return [0.8, 0.8, 1.2]; }, [neuronCount]); const outputBias = -0.5; // Generate data points const dataPoints = useMemo(() => { return datasets[selectedDataset].generate(); }, [selectedDataset, dataKey]); // Active neurons (limited by count) const activeNeurons = useMemo(() => neurons.slice(0, neuronCount), [neurons, neuronCount]); // Compute data in all spaces const processedData = useMemo(() => { return dataPoints.map((d, idx) => { const activations = activeNeurons.map(n => neuronActivation(d.pos, n.w, n.b)); const h = activations.map(a => a.post); // Pad to 3D for visualization while (h.length < 3) h.push(0); const foldedZ = (h[0] * 0.4 + h[1] * 0.4 + h[2] * 0.3) * liftAmount; const folded = [d.pos[0], d.pos[1], foldedZ]; const hidden = [h[0], h[1], h[2]]; const output = outputWeights.reduce((sum, w, i) => sum + w * (activations[i]?.post || 0), 0) + outputBias; return { ...d, idx, activations, h, folded, hidden, output }; }); }, [dataPoints, activeNeurons, liftAmount, outputWeights]); // Grid for folded view const gridPoints = useMemo(() => { const points = []; const res = 20, range = 1.4; for (let i = 0; i <= res; i++) { for (let j = 0; j <= res; j++) { const x = (i / res - 0.5) * 2 * range; const y = (j / res - 0.5) * 2 * range; const h = activeNeurons.map(n => relu(n.w[0] * x + n.w[1] * y + n.b)); while (h.length < 3) h.push(0); const z = (h[0] * 0.4 + h[1] * 0.4 + h[2] * 0.3) * liftAmount; points.push({ pos: [x, y], lifted: [x, y, z], h }); } } return points; }, [activeNeurons, liftAmount]); // Neuron lines const neuronLines = useMemo(() => { return activeNeurons.map(({ w, b }) => { const range = 2; if (Math.abs(w[1]) > Math.abs(w[0])) { return [[-range, (-w[0] * -range - b) / w[1]], [range, (-w[0] * range - b) / w[1]]]; } else if (Math.abs(w[0]) > 0.01) { return [[(-w[1] * -range - b) / w[0], -range], [(-w[1] * range - b) / w[0], range]]; } return [[-range, -b / (w[1] || 0.01)], [range, -b / (w[1] || 0.01)]]; }); }, [activeNeurons]); // Decision boundary in 2D const decisionBoundary2D = useMemo(() => { const points = []; const res = 35, range = 1.4; for (let i = 0; i < res; i++) { for (let j = 0; j < res; j++) { const corners = [ [(i/res - 0.5) * 2 * range, (j/res - 0.5) * 2 * range], [((i+1)/res - 0.5) * 2 * range, (j/res - 0.5) * 2 * range], [((i+1)/res - 0.5) * 2 * range, ((j+1)/res - 0.5) * 2 * range], [(i/res - 0.5) * 2 * range, ((j+1)/res - 0.5) * 2 * range] ]; const outputs = corners.map(p => { const h = activeNeurons.map(n => relu(n.w[0] * p[0] + n.w[1] * p[1] + n.b)); return outputWeights.reduce((sum, w, i) => sum + w * (h[i] || 0), 0) + outputBias; }); if (outputs.some(o => o > 0) && outputs.some(o => o < 0)) { const cx = (corners[0][0] + corners[2][0]) / 2; const cy = (corners[0][1] + corners[2][1]) / 2; const h = activeNeurons.map(n => relu(n.w[0] * cx + n.w[1] * cy + n.b)); while (h.length < 3) h.push(0); const z = (h[0] * 0.4 + h[1] * 0.4 + h[2] * 0.3) * liftAmount; points.push([cx, cy, z]); } } } return points; }, [activeNeurons, liftAmount, outputWeights]); const handleMouseDown = useCallback((view) => (e) => { setDragging(view); setLastMouse([e.clientX, e.clientY]); }, []); const handleMouseMove = useCallback((e) => { if (!dragging) return; const dx = e.clientX - lastMouse[0]; const dy = e.clientY - lastMouse[1]; if (dragging === 'fold') { setFoldRotY(prev => prev + dx * 0.01); setFoldRotX(prev => Math.max(-Math.PI/2, Math.min(Math.PI/2, prev + dy * 0.01))); } else if (dragging === 'hidden') { setHiddenRotY(prev => prev + dx * 0.01); setHiddenRotX(prev => Math.max(-Math.PI/2, Math.min(Math.PI/2, prev + dy * 0.01))); } setLastMouse([e.clientX, e.clientY]); }, [dragging, lastMouse]); const handleMouseUp = useCallback(() => setDragging(null), []); const updateNeuron = (idx, field, subfield, value) => { setNeurons(prev => { const next = [...prev]; if (subfield !== null) { next[idx] = { ...next[idx], [field]: [...next[idx][field]] }; next[idx][field][subfield] = value; } else { next[idx] = { ...next[idx], [field]: value }; } return next; }); }; const handleDatasetChange = (key) => { setSelectedDataset(key); setNeurons(datasets[key].recommendedNeurons); setNeuronCount(3); setDataKey(prev => prev + 1); setHighlightedPoint(null); }; const handleNeuronCountChange = (count) => { setNeuronCount(count); setActiveNeuron(null); }; const neuronColors = ['#ff6b6b', '#4ecdc4', '#ffd93d']; const gridRes = 21; const separationInfo = useMemo(() => { const correct = processedData.filter(d => (d.output > 0) === (d.class === 1)).length; return { accuracy: (correct / processedData.length * 100).toFixed(0), correct, total: processedData.length }; }, [processedData]); return (

Neural Network Geometry: How Hidden Layers Enable Separation

Explore how neurons fold 2D input space to make non-linear patterns linearly separable

{/* Dataset and config row */}
{/* Dataset selector */}
Dataset — Choose a non-linearly-separable pattern
{Object.entries(datasets).map(([key, { name }]) => ( ))}
{datasets[selectedDataset].description}
{/* Neuron count */}
Hidden Neurons
{[1, 2, 3].map(n => ( ))}
{neuronCount === 1 ? 'Single fold line' : neuronCount === 2 ? 'Two fold lines → 2D hidden' : 'Three fold lines → 3D hidden'}
{/* Lift control + accuracy */}
Lift Amount {(liftAmount * 100).toFixed(0)}%
setLiftAmount(parseFloat(e.target.value))} style={{ width: '100%' }} />
Accuracy 80 ? '#fbbf24' : '#fff' }}> {separationInfo.accuracy}%
{/* Main visualization row */}
{/* 2D Input Space */}
① Input (x, y) Not Separable
{/* Grid */} {Array.from({ length: 7 }).map((_, i) => { const pos = 20 + i * 29; return ( ); })} {/* Neuron lines */} {neuronLines.map((line, i) => { const color = neuronColors[i]; const toScreen = (p) => [107 + p[0] * 60, 107 - p[1] * 60]; const p1 = toScreen(line[0]); const p2 = toScreen(line[1]); const isActive = activeNeuron === i; return ( ); })} {/* Decision boundary */} {decisionBoundary2D.map((p, i) => { const x = 107 + p[0] * 60; const y = 107 - p[1] * 60; return ; })} {/* Data points */} {processedData.map((d, i) => { const x = 107 + d.pos[0] * 60; const y = 107 - d.pos[1] * 60; const color = d.class === 0 ? '#64ffda' : '#ff6b9d'; const isHighlighted = highlightedPoint === i; return ( {isHighlighted && } setHighlightedPoint(i)} onMouseLeave={() => setHighlightedPoint(null)} /> ); })}
Class 0   Class 1
{/* Arrow */}
{/* Hidden Space */}
② Hidden Space ({neuronCount === 1 ? 'h₁' : neuronCount === 2 ? 'h₁, h₂' : 'h₁, h₂, h₃'}) {neuronCount >= 2 ? 'Separable!' : 'Maybe...'}
{/* Decision plane */} {showDecisionPlane && neuronCount >= 2 && ( {(() => { const lines = []; const range = 1.4; const steps = 10; for (let i = 0; i <= steps; i++) { for (let j = 0; j <= steps; j++) { const h1 = (i / steps) * range; const h2 = (j / steps) * range; let h3 = 0; if (neuronCount === 3 && outputWeights[2]) { h3 = -(outputWeights[0] * h1 + outputWeights[1] * h2 + outputBias) / outputWeights[2]; if (h3 < -0.1 || h3 > range * 1.2) continue; } else if (neuronCount === 2) { // For 2D, check if on the decision line const val = outputWeights[0] * h1 + outputWeights[1] * h2 + outputBias; if (Math.abs(val) > 0.15) continue; } const proj = project([h1, h2, h3], hiddenRotX, hiddenRotY, 110, [175, 140]); // Draw connections to neighbors const nextI = (i + 1) / steps * range; const nextJ = (j + 1) / steps * range; if (i < steps) { let h3Next = 0; if (neuronCount === 3 && outputWeights[2]) { h3Next = -(outputWeights[0] * nextI + outputWeights[1] * h2 + outputBias) / outputWeights[2]; } if (h3Next >= -0.1 && h3Next <= range * 1.2) { const projNext = project([nextI, h2, h3Next], hiddenRotX, hiddenRotY, 110, [175, 140]); lines.push( ); } } if (j < steps) { let h3Next = 0; if (neuronCount === 3 && outputWeights[2]) { h3Next = -(outputWeights[0] * h1 + outputWeights[1] * nextJ + outputBias) / outputWeights[2]; } if (h3Next >= -0.1 && h3Next <= range * 1.2) { const projNext = project([h1, nextJ, h3Next], hiddenRotX, hiddenRotY, 110, [175, 140]); lines.push( ); } } } } return lines; })()} )} {/* Data points */} {processedData .map((d, i) => ({ ...d, proj: project(d.hidden, hiddenRotX, hiddenRotY, 110, [175, 140]) })) .sort((a, b) => a.proj[2] - b.proj[2]) .map((d) => { const color = d.class === 0 ? '#64ffda' : '#ff6b9d'; const isHighlighted = highlightedPoint === d.idx; const correct = (d.output > 0) === (d.class === 1); return ( {isHighlighted && ( )} setHighlightedPoint(d.idx)} onMouseLeave={() => setHighlightedPoint(null)} /> ); })} {/* Axes */} {[ { dir: [1.2, 0, 0], color: neuronColors[0], label: 'h₁', show: neuronCount >= 1 }, { dir: [0, 1.2, 0], color: neuronColors[1], label: 'h₂', show: neuronCount >= 2 }, { dir: [0, 0, 1.2], color: neuronColors[2], label: 'h₃', show: neuronCount >= 3 } ].filter(a => a.show).map(({ dir, color, label }) => { const origin = project([0, 0, 0], hiddenRotX, hiddenRotY, 110, [175, 140]); const end = project(dir, hiddenRotX, hiddenRotY, 110, [175, 140]); return ( {label} ); })}
Decision {neuronCount === 2 ? 'line' : 'plane'}
{/* Arrow */}
{/* Folded View */}
③ Folded View (x, y, lift)
{/* Grid mesh */} {showGrid && gridPoints.map((point, idx) => { const i = idx % gridRes, j = Math.floor(idx / gridRes); if (i >= gridRes - 1 || j >= gridRes - 1) return null; const p1 = project(point.lifted, foldRotX, foldRotY, 85, [175, 145]); const p2 = project(gridPoints[idx + 1].lifted, foldRotX, foldRotY, 85, [175, 145]); const p3 = project(gridPoints[idx + gridRes].lifted, foldRotX, foldRotY, 85, [175, 145]); const alpha = 0.05 + Math.max(0, point.lifted[2]) * 0.06; return ( ); })} {/* Fold lines */} {neuronLines.map((line, ni) => { const color = neuronColors[ni]; const isActive = activeNeuron === ni; const segments = []; const steps = 20; for (let s = 0; s < steps; s++) { const t1 = s / steps, t2 = (s + 1) / steps; const p1 = [ line[0][0] + (line[1][0] - line[0][0]) * t1, line[0][1] + (line[1][1] - line[0][1]) * t1 ]; const p2 = [ line[0][0] + (line[1][0] - line[0][0]) * t2, line[0][1] + (line[1][1] - line[0][1]) * t2 ]; const h1 = activeNeurons.map(n => relu(n.w[0] * p1[0] + n.w[1] * p1[1] + n.b)); const h2 = activeNeurons.map(n => relu(n.w[0] * p2[0] + n.w[1] * p2[1] + n.b)); while (h1.length < 3) h1.push(0); while (h2.length < 3) h2.push(0); const z1 = (h1[0] * 0.4 + h1[1] * 0.4 + h1[2] * 0.3) * liftAmount; const z2 = (h2[0] * 0.4 + h2[1] * 0.4 + h2[2] * 0.3) * liftAmount; const proj1 = project([p1[0], p1[1], z1], foldRotX, foldRotY, 85, [175, 145]); const proj2 = project([p2[0], p2[1], z2], foldRotX, foldRotY, 85, [175, 145]); segments.push( ); } return {segments}; })} {/* Decision boundary */} {decisionBoundary2D.map((p, i) => { const proj = project(p, foldRotX, foldRotY, 85, [175, 145]); return ; })} {/* Data points */} {processedData.map((d) => { const proj = project(d.folded, foldRotX, foldRotY, 85, [175, 145]); const color = d.class === 0 ? '#64ffda' : '#ff6b9d'; const isHighlighted = highlightedPoint === d.idx; const correct = (d.output > 0) === (d.class === 1); return ( {isHighlighted && } setHighlightedPoint(d.idx)} onMouseLeave={() => setHighlightedPoint(null)} /> ); })} {/* Axes */} {[ { dir: [1.3, 0, 0], color: '#555', label: 'x' }, { dir: [0, 1.3, 0], color: '#555', label: 'y' }, { dir: [0, 0, 1], color: '#ffd700', label: 'lift' } ].map(({ dir, color, label }) => { const origin = project([0, 0, 0], foldRotX, foldRotY, 85, [175, 145]); const end = project(dir, foldRotX, foldRotY, 85, [175, 145]); return ( {label} ); })}
Original coords + height from neuron outputs
{/* Neuron controls + point inspector */}
{/* Neuron controls */}
Neuron Parameters — Adjust fold lines
{neurons.slice(0, neuronCount).map((neuron, i) => (
setActiveNeuron(activeNeuron === i ? null : i)} >
h{i + 1} {neuron.w[0].toFixed(1)}x{neuron.w[1] >= 0 ? '+' : ''}{neuron.w[1].toFixed(1)}y{neuron.b >= 0 ? '+' : ''}{neuron.b.toFixed(1)}
{['w₁', 'w₂', 'b'].map((label, li) => (
{label}
updateNeuron(i, li < 2 ? 'w' : 'b', li < 2 ? li : null, parseFloat(e.target.value))} onClick={(e) => e.stopPropagation()} style={{ width: '100%' }} />
))}
))} {/* Disabled neurons */} {neurons.slice(neuronCount).map((_, i) => (
h{neuronCount + i + 1} disabled
))}
{/* Point inspector */}
{highlightedPoint !== null ? 'Point Details' : 'Hover a Point'}
{highlightedPoint !== null ? (
True Class {processedData[highlightedPoint].class}
{activeNeurons.map((_, ni) => { const act = processedData[highlightedPoint].activations[ni]; return (
h{ni + 1} = 0 ? '#4ade80' : '#666', width: '40px' }}> {act.pre.toFixed(2)} {act.post.toFixed(2)} {act.pre < 0 && ( zeroed )}
); })}
0) === (processedData[highlightedPoint].class === 1) ? 'rgba(74, 222, 128, 0.1)' : 'rgba(255, 107, 107, 0.1)', borderRadius: '4px' }}> Output: 0 ? '#4ade80' : '#f87171' }}> {processedData[highlightedPoint].output.toFixed(2)} → class {processedData[highlightedPoint].output > 0 ? 1 : 0}
) : (
Hover points to see neuron activations and how ReLU transforms them.
)}
{/* Compact explanation */}
Hidden Space — Each axis = one neuron's ReLU output (distance from line, positive side only)
Folded View — Same transform shown as the input plane being creased along fold lines
Decision Plane — A flat cut in hidden space → complex boundary in input space
); } ReactDOM.render(, document.getElementById('nn-app-root'));