use std::{panic, result::Result, pin::Pin};
use exit_future::Signal;
use log::{debug, error};
use futures::{
	Future, FutureExt, StreamExt,
	future::{select, Either, BoxFuture, join_all, try_join_all, pending},
	sink::SinkExt, task::{Context, Poll},
};
use prometheus_endpoint::{
	exponential_buckets, register,
	PrometheusError,
	CounterVec, HistogramOpts, HistogramVec, Opts, Registry, U64
};
use sp_utils::mpsc::{TracingUnboundedSender, TracingUnboundedReceiver, tracing_unbounded};
use tracing_futures::Instrument;
use crate::{config::{TaskExecutor, TaskType, JoinFuture}, Error};
use sc_telemetry::TelemetrySpan;
mod prometheus_future;
#[cfg(test)]
mod tests;
/// Future adapter that polls `inner` while an optional telemetry span is entered.
struct WithTelemetrySpan<T> {
	/// Span entered for the duration of every `poll` call; nothing is entered when `None`.
	span: Option<TelemetrySpan>,
	/// The wrapped future.
	inner: T,
}
impl<T> WithTelemetrySpan<T> {
	fn new(span: Option<TelemetrySpan>, inner: T) -> Self {
		Self {
			span,
			inner,
		}
	}
}
impl<T: Future<Output = ()> + Unpin> Future for WithTelemetrySpan<T> {
	type Output = ();

	/// Polls the wrapped future with the telemetry span (when present) entered.
	///
	/// The span guard lives until this function returns, so the span is exited again as
	/// soon as the inner `poll` call is over.
	fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
		// `Self` is `Unpin`, so a plain `&mut Self` can be obtained. Through it the
		// `span` and `inner` fields can be borrowed independently, which avoids cloning
		// the span on every single poll.
		let this = self.get_mut();
		let _enter = this.span.as_ref().map(|s| s.enter());
		Pin::new(&mut this.inner).poll(ctx)
	}
}
/// A cloneable handle used to spawn background tasks.
#[derive(Clone)]
pub struct SpawnTaskHandle {
	/// Exit future raced against every spawned task; the task is dropped once it resolves.
	on_exit: exit_future::Exit,
	/// Executor the tasks are handed to.
	executor: TaskExecutor,
	/// Task metrics; `None` means spawned tasks are not instrumented.
	metrics: Option<Metrics>,
	/// Channel over which the join handle of every spawned task is sent. Once it is closed,
	/// spawning new tasks is refused.
	task_notifier: TracingUnboundedSender<JoinFuture>,
	/// Telemetry span entered whenever a spawned task is polled, if any.
	telemetry_span: Option<TelemetrySpan>,
}
impl SpawnTaskHandle {
	/// Spawns `task` as an asynchronous task identified by `name`.
	///
	/// When metrics are enabled, `name` is used as a label value on the task metrics, so
	/// the set of names in use should stay small and bounded rather than being generated
	/// at runtime.
	///
	/// The task is dropped as soon as the exit future of this handle resolves. Nothing is
	/// spawned at all once the task notifier channel has been closed.
	pub fn spawn(&self, name: &'static str, task: impl Future<Output = ()> + Send + 'static) {
		self.spawn_inner(name, task, TaskType::Async)
	}

	/// Same as [`SpawnTaskHandle::spawn`], but hands the task to the executor as
	/// `TaskType::Blocking`.
	pub fn spawn_blocking(&self, name: &'static str, task: impl Future<Output = ()> + Send + 'static) {
		self.spawn_inner(name, task, TaskType::Blocking)
	}

	/// Shared implementation of [`SpawnTaskHandle::spawn`] and
	/// [`SpawnTaskHandle::spawn_blocking`].
	///
	/// Wraps `task` so that it is interrupted by the exit future and, when metrics are
	/// enabled, instrumented. It then spawns the task and forwards its join handle through
	/// the task notifier channel.
	fn spawn_inner(
		&self,
		name: &'static str,
		task: impl Future<Output = ()> + Send + 'static,
		task_type: TaskType,
	) {
		if self.task_notifier.is_closed() {
			debug!("Attempt to spawn a new task has been prevented: {}", name);
			return;
		}

		// Metrics that can be updated before the task has been polled even once.
		if let Some(metrics) = self.metrics.as_ref() {
			metrics.tasks_spawned.with_label_values(&[name]).inc();
			// Adding zero is a dummy increase that makes the `finished` series for this
			// task show up in the metrics right away.
			metrics.tasks_ended.with_label_values(&[name, "finished"]).inc_by(0);
		}

		let exit = self.on_exit.clone();
		let task_metrics = self.metrics.clone();

		let wrapped = async move {
			match task_metrics {
				// No metrics: simply race the task against the exit future.
				None => {
					futures::pin_mut!(task);
					let _ = select(exit, task).await;
				}
				Some(metrics) => {
					let poll_duration = metrics.poll_duration.with_label_values(&[name]);
					let poll_start = metrics.poll_start.with_label_values(&[name]);
					let monitored =
						prometheus_future::with_poll_durations(poll_duration, poll_start, task);
					// Panics are caught so that they can be counted below before being
					// propagated again.
					let guarded = panic::AssertUnwindSafe(monitored).catch_unwind();
					futures::pin_mut!(guarded);

					match select(exit, guarded).await {
						// The exit future resolved first; the task gets dropped.
						Either::Left(((), _)) => {
							metrics.tasks_ended.with_label_values(&[name, "interrupted"]).inc();
						}
						Either::Right((Ok(()), _)) => {
							metrics.tasks_ended.with_label_values(&[name, "finished"]).inc();
						}
						Either::Right((Err(payload), _)) => {
							metrics.tasks_ended.with_label_values(&[name, "panic"]).inc();
							panic::resume_unwind(payload)
						}
					}
				}
			}
		};

		let instrumented = wrapped.in_current_span().boxed();
		let with_span = WithTelemetrySpan::new(self.telemetry_span.clone(), instrumented);
		let join_handle = self.executor.spawn(with_span.boxed(), task_type);

		// The join handle is delivered from a separate asynchronous task.
		let mut notifier = self.task_notifier.clone();
		let deliver_handle = async move {
			if let Err(err) = notifier.send(join_handle).await {
				error!("Could not send spawned task handle to queue: {}", err);
			}
		};
		self.executor.spawn(Box::pin(deliver_handle), TaskType::Async);
	}
}
impl sp_core::traits::SpawnNamed for SpawnTaskHandle {
	fn spawn_blocking(&self, name: &'static str, future: BoxFuture<'static, ()>) {
		self.spawn_blocking(name, future);
	}
	fn spawn(&self, name: &'static str, future: BoxFuture<'static, ()>) {
		self.spawn(name, future);
	}
}
/// A handle for spawning essential tasks: tasks whose termination, for whatever reason,
/// is reported by closing the essential-failure channel.
pub struct SpawnEssentialTaskHandle {
	/// Sender side of the essential-failure channel, closed when an essential task ends.
	essential_failed_tx: TracingUnboundedSender<()>,
	/// Handle that performs the actual spawning.
	inner: SpawnTaskHandle,
}
impl SpawnEssentialTaskHandle {
	/// Creates a new handle that spawns through `spawn_task_handle` and reports the end of
	/// an essential task by closing `essential_failed_tx`.
	pub fn new(
		essential_failed_tx: TracingUnboundedSender<()>,
		spawn_task_handle: SpawnTaskHandle,
	) -> SpawnEssentialTaskHandle {
		Self { essential_failed_tx, inner: spawn_task_handle }
	}

	/// Spawns `task` as an essential asynchronous task identified by `name`.
	///
	/// If the task ever completes, by returning or by panicking, an error is logged and
	/// the essential-failure channel is closed.
	pub fn spawn(&self, name: &'static str, task: impl Future<Output = ()> + Send + 'static) {
		self.spawn_inner(name, task, TaskType::Async)
	}

	/// Same as [`SpawnEssentialTaskHandle::spawn`], but hands the task to the executor as
	/// `TaskType::Blocking`.
	pub fn spawn_blocking(
		&self,
		name: &'static str,
		task: impl Future<Output = ()> + Send + 'static,
	) {
		self.spawn_inner(name, task, TaskType::Blocking)
	}

	/// Wraps `task` with the failure reporting and spawns it through the inner handle.
	fn spawn_inner(
		&self,
		name: &'static str,
		task: impl Future<Output = ()> + Send + 'static,
		task_type: TaskType,
	) {
		let failure_tx = self.essential_failed_tx.clone();
		let essential_task = async move {
			// Whether the task returned or panicked makes no difference: an essential
			// task is never supposed to end. The outcome is kept alive until the end of
			// this block, i.e. until after the failure has been reported.
			let _outcome = std::panic::AssertUnwindSafe(task).catch_unwind().await;
			log::error!("Essential task `{}` failed. Shutting down service.", name);
			let _ = failure_tx.close_channel();
		};

		self.inner.spawn_inner(name, essential_task, task_type);
	}
}
/// Helper struct to manage background and essential tasks, together with child managers.
pub struct TaskManager {
	/// Exit future handed to every spawn handle; spawned tasks are raced against it and
	/// [`TaskManager::future`] resolves with `Ok(())` once it resolves.
	on_exit: exit_future::Exit,
	/// Signal fired by [`TaskManager::terminate`]. It is taken out of the `Option` when
	/// fired, so termination only happens once.
	signal: Option<Signal>,
	/// Executor on which the tasks are spawned.
	executor: TaskExecutor,
	/// Task metrics; present only when a Prometheus registry was supplied on creation.
	metrics: Option<Metrics>,
	/// Sender side of the essential-failure channel; a clone is given to every
	/// `SpawnEssentialTaskHandle`.
	essential_failed_tx: TracingUnboundedSender<()>,
	/// Receiver side of the essential-failure channel, watched by [`TaskManager::future`].
	essential_failed_rx: TracingUnboundedReceiver<()>,
	/// Values that are kept alive until the manager is dropped or a clean shutdown ends.
	keep_alive: Box<dyn std::any::Any + Send + Sync>,
	/// Sender over which the join handles of spawned tasks are delivered; closed on
	/// termination.
	task_notifier: TracingUnboundedSender<JoinFuture>,
	/// Join handle of the task that drives all delivered join handles to completion;
	/// awaited by [`TaskManager::clean_shutdown`].
	completion_future: JoinFuture,
	/// Child task managers. They are terminated and shut down together with this manager,
	/// and an error from any of them makes [`TaskManager::future`] resolve with that
	/// error.
	children: Vec<TaskManager>,
	/// Telemetry span passed on to the spawn handles, if any.
	telemetry_span: Option<TelemetrySpan>,
}
impl TaskManager {
	
	
	pub(super) fn new(
		executor: TaskExecutor,
		prometheus_registry: Option<&Registry>,
		telemetry_span: Option<TelemetrySpan>,
	) -> Result<Self, PrometheusError> {
		let (signal, on_exit) = exit_future::signal();
		
		let (essential_failed_tx, essential_failed_rx) = tracing_unbounded("mpsc_essential_tasks");
		let metrics = prometheus_registry.map(Metrics::register).transpose()?;
		let (task_notifier, background_tasks) = tracing_unbounded("mpsc_background_tasks");
		
		
		
		let completion_future = executor.spawn(
			Box::pin(background_tasks.for_each_concurrent(None, |x| x)),
			TaskType::Async,
		);
		Ok(Self {
			on_exit,
			signal: Some(signal),
			executor,
			metrics,
			essential_failed_tx,
			essential_failed_rx,
			keep_alive: Box::new(()),
			task_notifier,
			completion_future,
			children: Vec::new(),
			telemetry_span,
		})
	}
	
	pub fn spawn_handle(&self) -> SpawnTaskHandle {
		SpawnTaskHandle {
			on_exit: self.on_exit.clone(),
			executor: self.executor.clone(),
			metrics: self.metrics.clone(),
			task_notifier: self.task_notifier.clone(),
			telemetry_span: self.telemetry_span.clone(),
		}
	}
	
	pub fn spawn_essential_handle(&self) -> SpawnEssentialTaskHandle {
		SpawnEssentialTaskHandle::new(self.essential_failed_tx.clone(), self.spawn_handle())
	}
	
	
	
	
	
	
	
	
	
	pub fn clean_shutdown(mut self) -> Pin<Box<dyn Future<Output = ()> + Send>> {
		self.terminate();
		let children_shutdowns = self.children.into_iter().map(|x| x.clean_shutdown());
		let keep_alive = self.keep_alive;
		let completion_future = self.completion_future;
		Box::pin(async move {
			join_all(children_shutdowns).await;
			completion_future.await;
			drop(keep_alive);
		})
	}
	
	
	
	
	
	
	
	pub fn future<'a>(&'a mut self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
		Box::pin(async move {
			let mut t1 = self.essential_failed_rx.next().fuse();
			let mut t2 = self.on_exit.clone().fuse();
			let mut t3 = try_join_all(
				self.children.iter_mut().map(|x| x.future())
					
					
					.chain(std::iter::once(pending().boxed()))
			).fuse();
			futures::select! {
				_ = t1 => Err(Error::Other("Essential task failed.".into())),
				_ = t2 => Ok(()),
				res = t3 => Err(res.map(|_| ()).expect_err("this future never ends; qed")),
			}
		})
	}
	
	pub fn terminate(&mut self) {
		if let Some(signal) = self.signal.take() {
			let _ = signal.fire();
			
			self.task_notifier.close_channel();
			for child in self.children.iter_mut() {
				child.terminate();
			}
		}
	}
	
	pub fn keep_alive<T: 'static + Send + Sync>(&mut self, to_keep_alive: T) {
		
		use std::mem;
		let old = mem::replace(&mut self.keep_alive, Box::new(()));
		self.keep_alive = Box::new((to_keep_alive, old));
	}
	
	
	
	pub fn add_child(&mut self, child: TaskManager) {
		self.children.push(child);
	}
}
/// Prometheus metrics describing the spawned tasks.
#[derive(Clone)]
struct Metrics {
	// All metrics are labelled by `task_name`; `tasks_ended` also carries a `reason` label.
	/// Histogram of the duration of each `Future::poll` invocation.
	poll_duration: HistogramVec,
	/// Number of `Future::poll` invocations that have been started.
	poll_start: CounterVec<U64>,
	/// Number of tasks that have been spawned.
	tasks_spawned: CounterVec<U64>,
	/// Number of tasks that have ended, per reason (`finished`, `panic`, `interrupted`).
	tasks_ended: CounterVec<U64>,
}
impl Metrics {
	fn register(registry: &Registry) -> Result<Self, PrometheusError> {
		Ok(Self {
			poll_duration: register(HistogramVec::new(
				HistogramOpts {
					common_opts: Opts::new(
						"tasks_polling_duration",
						"Duration in seconds of each invocation of Future::poll"
					),
					buckets: exponential_buckets(0.001, 4.0, 9)
						.expect("function parameters are constant and always valid; qed"),
				},
				&["task_name"]
			)?, registry)?,
			poll_start: register(CounterVec::new(
				Opts::new(
					"tasks_polling_started_total",
					"Total number of times we started invoking Future::poll"
				),
				&["task_name"]
			)?, registry)?,
			tasks_spawned: register(CounterVec::new(
				Opts::new(
					"tasks_spawned_total",
					"Total number of tasks that have been spawned on the Service"
				),
				&["task_name"]
			)?, registry)?,
			tasks_ended: register(CounterVec::new(
				Opts::new(
					"tasks_ended_total",
					"Total number of tasks for which Future::poll has returned Ready(()) or panicked"
				),
				&["task_name", "reason"]
			)?, registry)?,
		})
	}
}