/* eslint-disable react/no-unknown-property */
import React, { useRef, useMemo } from "react";
import * as THREE from "three";
import { extend, createPortal, Node } from "@react-three/fiber";
import { geometry } from "maath";
import SimulationMaterial from "./SimulationMaterial";
import ColorMaterial from "./ColorMaterial";

// Register maath's geometry classes and the project's custom materials with R3F's
// element catalogue so they can be used as JSX intrinsics (e.g. <simulationMaterial />).
// The materials are spread after `geometry`, so they take precedence on any key clash.
extend({ ...geometry, SimulationMaterial, ColorMaterial });

// Augment R3F's ThreeElements interface so the custom materials registered above type-check as JSX intrinsics
declare module "@react-three/fiber" {
	interface ThreeElements {
		/** Props/typing for the `<simulationMaterial />` intrinsic (registered via `extend` in this file). */
		simulationMaterial: Node<SimulationMaterial, typeof SimulationMaterial>;
		/** Props/typing for the `<colorMaterial />` intrinsic (registered via `extend` in this file). */
		colorMaterial: Node<ColorMaterial, typeof ColorMaterial>;
	}
}

/** One entry of a particle state or colour sequence: per-particle data tagged with an id. */
type ParticleDataSet = { particles: Float32Array; id: string };

/**
 * Sets up two off-screen scenes, each holding a single quad drawn with a custom
 * material: `SimulationMaterial` (driven by the particle states) in `materialScene`
 * and `ColorMaterial` (driven by the particle colours) in `colorScene`.
 *
 * The meshes are mounted through `createPortal`, so the returned `ParticleTexture`
 * element must be rendered inside the R3F tree for the scenes to be populated.
 * NOTE(review): callers presumably render the two scenes into render targets and
 * drive the materials through the returned refs — confirm at the call site.
 *
 * Memo dependencies are compared by identity, so array arguments should be stable
 * (e.g. memoized) to avoid rebuilding the portals — and materials — every render.
 *
 * @param size - Forwarded as the first constructor argument of both materials.
 * @param stateTransitions - Forwarded as the third constructor argument of `SimulationMaterial`.
 * @param particleStates - Forwarded as the second constructor argument of `SimulationMaterial`.
 * @param colorTransitions - Forwarded as the third constructor argument of `ColorMaterial`.
 * @param particleColors - Forwarded as the second constructor argument of `ColorMaterial`.
 * @returns The portal element to render, the two material refs, and the two scenes.
 */
const useParticlesTexture = (
	size: number,
	stateTransitions: number[],
	particleStates: ParticleDataSet[],
	colorTransitions: number[],
	particleColors: ParticleDataSet[],
) => {
	const simulationMaterialRef = useRef<SimulationMaterial>(null);
	const colorMaterialRef = useRef<ColorMaterial>(null);

	// Created once per hook instance; used as the portal targets below.
	const materialScene = useMemo(() => new THREE.Scene(), []);
	const colorScene = useMemo(() => new THREE.Scene(), []);

	// Shared quad geometry: two triangles spanning [-1, 1] on x and y at z = 0.
	// The uv v-coordinate is inverted relative to y (y = -1 maps to v = 1).
	// NOTE(review): presumably intentional to match how the materials read their data — confirm before changing.
	const buffers = useMemo(() => {
		const positions = new Float32Array([-1, -1, 0, 1, -1, 0, 1, 1, 0, -1, -1, 0, 1, 1, 0, -1, 1, 0]);
		const uvs = new Float32Array([0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0]);
		return (
			<bufferGeometry>
				<bufferAttribute attach="attributes-position" count={positions.length / 3} array={positions} itemSize={3} />
				<bufferAttribute attach="attributes-uv" count={uvs.length / 2} array={uvs} itemSize={2} />
			</bufferGeometry>
		);
	}, []);

	// R3F reconstructs an element when its `args` change, so new inputs produce a fresh material instance.
	// The scenes are listed as dependencies for exhaustive-deps correctness; they never change identity.
	const positionsPortal = useMemo(() => createPortal(
		<mesh>
			<simulationMaterial ref={simulationMaterialRef} args={[size, particleStates, stateTransitions]} />
			{buffers}
		</mesh>,
		materialScene,
	), [size, particleStates, stateTransitions, buffers, materialScene]);
	const colorsPortal = useMemo(() => createPortal(
		<mesh>
			<colorMaterial ref={colorMaterialRef} args={[size, particleColors, colorTransitions]} />
			{buffers}
		</mesh>,
		colorScene,
	), [size, particleColors, colorTransitions, buffers, colorScene]);

	// Single element the caller renders to mount both portals.
	const ParticleTexture = useMemo(() => (
		<>
			{positionsPortal}
			{colorsPortal}
		</>
	), [positionsPortal, colorsPortal]);

	return {
		ParticleTexture,
		simulationMaterialRef,
		materialScene,
		colorMaterialRef,
		colorScene,
	};
};

// Expose the hook as this module's default export.
export { useParticlesTexture as default };
