#[cfg(test)]
mod validation;
use std::cmp::min;
use std::mem;
use std::vec::Vec;
use parity_wasm::{elements, builder};
use rules;
/// Re-targets `call` instructions after a function has been inserted into the function index
/// space at `inserted_index`.
///
/// Every `Call` whose target index is greater than or equal to `inserted_index` is bumped by
/// one; calls to lower indices and all other instructions are left untouched.
pub fn update_call_index(instructions: &mut elements::Instructions, inserted_index: u32) {
	use parity_wasm::elements::Instruction::*;

	for op in instructions.elements_mut().iter_mut() {
		match *op {
			Call(ref mut target) => {
				if *target >= inserted_index {
					*target += 1;
				}
			}
			_ => {}
		}
	}
}
/// Bookkeeping for one open control structure while metered blocks are being determined.
///
/// One entry is pushed per `block`, `if` and `loop` instruction (plus one for the function body
/// itself) and popped again on the matching `end`.
#[derive(Debug)]
struct ControlBlock {
	/// The lowest control stack index targeted by a forward branch (`br`, `br_if`, `br_table`
	/// or `return`) seen inside this control block so far, including branches seen inside
	/// nested control blocks that have already been closed. Branches that target a loop are
	/// not recorded.
	///
	/// The value starts out as the block's own stack index, so it is never greater than that
	/// index even when nothing branches out of the block.
	lowest_forward_br_target: usize,

	/// The metered block that instruction costs are currently being added to.
	active_metered_block: MeteredBlock,

	/// Whether this control block was opened by a `loop` instruction.
	is_loop: bool,
}

/// A run of instructions that is paid for by a single gas call inserted at `start_pos`.
#[derive(Debug)]
pub(crate) struct MeteredBlock {
	/// Index, in the original instruction list, of the instruction the gas call is placed
	/// in front of.
	start_pos: usize,
	/// Accumulated cost of the instructions attributed to this block.
	cost: u32,
}

/// State of the metered block analysis of a single function body.
struct Counter {
	/// Stack of the control blocks that are currently open. The entry at index 0 stands for
	/// the function body, so a branch label `n` refers to the entry `n` positions below the
	/// top of the stack.
	stack: Vec<ControlBlock>,

	/// Metered blocks that are complete and will not change any more.
	finalized_blocks: Vec<MeteredBlock>,
}

impl Counter {
	/// Creates a counter with no open control blocks and no finalized metered blocks.
	fn new() -> Counter {
		Counter {
			stack: Vec::new(),
			finalized_blocks: Vec::new(),
		}
	}

	/// Opens a new control block whose first metered block starts at `cursor`.
	fn begin_control_block(&mut self, cursor: usize, is_loop: bool) {
		// The new block lives at the current stack length; until a branch is seen this is
		// also the lowest index it can branch to.
		let index = self.stack.len();
		self.stack.push(ControlBlock {
			lowest_forward_br_target: index,
			active_metered_block: MeteredBlock {
				start_pos: cursor,
				cost: 0,
			},
			is_loop,
		})
	}

	/// Closes the control block on top of the stack; `cursor` is the position of its `end`.
	///
	/// Returns `Err(())` if there is no open control block or if merging costs overflows.
	fn finalize_control_block(&mut self, cursor: usize) -> Result<(), ()> {
		// Either finalizes the active metered block or merges its cost into the metered block
		// of the enclosing control block.
		self.finalize_metered_block(cursor)?;

		let closing_control_block = self.stack.pop().ok_or(())?;
		let closing_control_index = self.stack.len();

		// The function body itself was closed; there is no enclosing block to update.
		if self.stack.is_empty() {
			return Ok(())
		}

		// Anything the closed block may branch to, the enclosing block may branch to as well.
		{
			let control_block = self.stack.last_mut().ok_or(())?;
			control_block.lowest_forward_br_target = min(
				control_block.lowest_forward_br_target,
				closing_control_block.lowest_forward_br_target
			);
		}

		// If something inside the closed block may have branched past it, the instructions
		// following the `end` are not guaranteed to run, so the enclosing block's metered block
		// has to end here as well.
		let may_br_out = closing_control_block.lowest_forward_br_target < closing_control_index;
		if may_br_out {
			self.finalize_metered_block(cursor)?;
		}

		Ok(())
	}

	/// Ends the active metered block of the top control block and starts a fresh, zero-cost one
	/// at `cursor + 1`.
	///
	/// If the ended block starts at the same position as the active metered block of the
	/// enclosing control block, its cost is folded into that block instead of being recorded
	/// separately. Otherwise it is appended to `finalized_blocks`, unless its cost is zero.
	///
	/// Returns `Err(())` if there is no open control block or if folding the cost into the
	/// enclosing block would overflow `u32`.
	fn finalize_metered_block(&mut self, cursor: usize) -> Result<(), ()> {
		let closing_metered_block = {
			let control_block = self.stack.last_mut().ok_or(())?;
			mem::replace(
				&mut control_block.active_metered_block,
				MeteredBlock {
					start_pos: cursor + 1,
					cost: 0,
				}
			)
		};

		// `last_mut` succeeded above, so the stack holds at least one element.
		let last_index = self.stack.len() - 1;
		if last_index > 0 {
			let prev_control_block = self.stack.get_mut(last_index - 1)
				.expect("last_index is greater than 0; last_index is stack size - 1; qed");
			let prev_metered_block = &mut prev_control_block.active_metered_block;
			if closing_metered_block.start_pos == prev_metered_block.start_pos {
				// Checked like `increment`: a wrapped sum would silently under-charge gas.
				prev_metered_block.cost = prev_metered_block.cost
					.checked_add(closing_metered_block.cost)
					.ok_or(())?;
				return Ok(())
			}
		}

		if closing_metered_block.cost > 0 {
			self.finalized_blocks.push(closing_metered_block);
		}
		Ok(())
	}

	/// Handles a branch instruction at `cursor` whose possible targets are the control stack
	/// indices in `indices`.
	///
	/// The active metered block ends at the branch. Every target that is not a loop is recorded
	/// in the top control block's `lowest_forward_br_target`.
	///
	/// Returns `Err(())` if the stack is empty or an index does not refer to an open block.
	fn branch(&mut self, cursor: usize, indices: &[usize]) -> Result<(), ()> {
		self.finalize_metered_block(cursor)?;

		for &index in indices {
			let target_is_loop = {
				let target_block = self.stack.get(index).ok_or(())?;
				target_block.is_loop
			};
			// A branch to a loop is not a forward jump, so it is not recorded.
			if target_is_loop {
				continue;
			}

			let control_block = self.stack.last_mut().ok_or(())?;
			control_block.lowest_forward_br_target =
				min(control_block.lowest_forward_br_target, index);
		}

		Ok(())
	}

	/// Returns the stack index of the innermost open control block, or `None` if none is open.
	fn active_control_block_index(&self) -> Option<usize> {
		self.stack.len().checked_sub(1)
	}

	/// Returns the metered block that costs are currently added to, or `Err(())` if no control
	/// block is open.
	fn active_metered_block(&mut self) -> Result<&mut MeteredBlock, ()> {
		let top_block = self.stack.last_mut().ok_or(())?;
		Ok(&mut top_block.active_metered_block)
	}

	/// Adds `val` to the cost of the active metered block.
	///
	/// Returns `Err(())` if no control block is open or if the cost would overflow `u32`.
	fn increment(&mut self, val: u32) -> Result<(), ()> {
		let top_block = self.active_metered_block()?;
		top_block.cost = top_block.cost.checked_add(val).ok_or(())?;
		Ok(())
	}
}
/// Replaces every `grow_memory` instruction with a call to the function at index
/// `grow_counter_func`.
///
/// Returns the number of instructions that were replaced.
fn inject_grow_counter(instructions: &mut elements::Instructions, grow_counter_func: u32) -> usize {
	use parity_wasm::elements::Instruction::*;

	let mut replaced = 0;
	let grow_sites = instructions.elements_mut().iter_mut().filter(|op| match **op {
		GrowMemory(_) => true,
		_ => false,
	});
	for op in grow_sites {
		*op = Call(grow_counter_func);
		replaced += 1;
	}
	replaced
}
/// Appends a `(i32) -> i32` helper function to `module` that charges gas for growing memory and
/// then performs the growth.
///
/// The helper multiplies its argument by `rules.grow_cost()`, passes the product to the
/// function at index `gas_func` and finally executes `grow_memory` with the original argument.
fn add_grow_counter(module: elements::Module, rules: &rules::Set, gas_func: u32) -> elements::Module {
	use parity_wasm::elements::Instruction::*;

	let body = elements::Instructions::new(vec![
		// Kept on the stack as the operand of `grow_memory` below.
		GetLocal(0),
		// argument * grow cost, consumed by the gas call.
		GetLocal(0),
		I32Const(rules.grow_cost() as i32),
		I32Mul,
		// NOTE(review): this sequence relies on `gas_func` leaving nothing on the stack.
		Call(gas_func),
		GrowMemory(0),
		End,
	]);

	let grow_counter = builder::function()
		.signature().params().i32().build().with_return_type(Some(elements::ValueType::I32)).build()
		.body()
			.with_instructions(body)
			.build()
		.build();

	let mut module_builder = builder::from_module(module);
	module_builder.push_function(grow_counter);
	module_builder.build()
}
pub(crate) fn determine_metered_blocks(
	instructions: &elements::Instructions,
	rules: &rules::Set,
) -> Result<Vec<MeteredBlock>, ()> {
	use parity_wasm::elements::Instruction::*;
	let mut counter = Counter::new();
	
	counter.begin_control_block(0, false);
	for cursor in 0..instructions.elements().len() {
		let instruction = &instructions.elements()[cursor];
		let instruction_cost = rules.process(instruction)?;
		match *instruction {
			Block(_) => {
				counter.increment(instruction_cost)?;
				
				
				
				
				let top_block_start_pos = counter.active_metered_block()?.start_pos;
				counter.begin_control_block(top_block_start_pos, false);
			}
			If(_) => {
				counter.increment(instruction_cost)?;
				counter.begin_control_block(cursor + 1, false);
			}
			Loop(_) => {
				counter.increment(instruction_cost)?;
				counter.begin_control_block(cursor + 1, true);
			}
			End => {
				counter.finalize_control_block(cursor)?;
			},
			Else => {
				counter.finalize_metered_block(cursor)?;
			}
			Br(label) | BrIf(label) => {
				counter.increment(instruction_cost)?;
				
				let active_index = counter.active_control_block_index().ok_or_else(|| ())?;
				let target_index = active_index.checked_sub(label as usize).ok_or_else(|| ())?;
				counter.branch(cursor, &[target_index])?;
			}
			BrTable(ref br_table_data) => {
				counter.increment(instruction_cost)?;
				let active_index = counter.active_control_block_index().ok_or_else(|| ())?;
				let target_indices = [br_table_data.default]
					.iter()
					.chain(br_table_data.table.iter())
					.map(|label| active_index.checked_sub(*label as usize))
					.collect::<Option<Vec<_>>>()
					.ok_or_else(|| ())?;
				counter.branch(cursor, &target_indices)?;
			}
			Return => {
				counter.increment(instruction_cost)?;
				counter.branch(cursor, &[0])?;
			}
			_ => {
				
				counter.increment(instruction_cost)?;
			}
		}
	}
	counter.finalized_blocks.sort_unstable_by_key(|block| block.start_pos);
	Ok(counter.finalized_blocks)
}
/// Inserts gas metering calls into a single function body.
///
/// First determines the metered blocks of `instructions` under `rules`, then inserts a call to
/// the function at index `gas_func` at the start of each of them. Fails with `Err(())` if either
/// step fails.
pub fn inject_counter(
	instructions: &mut elements::Instructions,
	rules: &rules::Set,
	gas_func: u32,
) -> Result<(), ()> {
	match determine_metered_blocks(instructions, rules) {
		Ok(blocks) => insert_metering_calls(instructions, blocks, gas_func),
		Err(()) => Err(()),
	}
}
/// Rebuilds `instructions` with an `i32.const <cost>; call <gas_func>` pair in front of the
/// first instruction of every metered block.
///
/// `blocks` is consumed front to back and each block is matched against the position of the
/// original instruction being copied, so the blocks have to be ordered by ascending
/// `start_pos`. If any block is left over once all instructions have been copied, `Err(())` is
/// returned; `instructions` has already been rewritten at that point.
fn insert_metering_calls(
	instructions: &mut elements::Instructions,
	blocks: Vec<MeteredBlock>,
	gas_func: u32,
)
	-> Result<(), ()>
{
	use parity_wasm::elements::Instruction::*;

	// Two extra instructions are emitted per metered block.
	let capacity = instructions.elements().len() + 2 * blocks.len();
	let original = mem::replace(instructions.elements_mut(), Vec::with_capacity(capacity));
	let rewritten = instructions.elements_mut();

	let mut pending = blocks.into_iter().peekable();
	for (pos, instr) in original.into_iter().enumerate() {
		let block_starts_here = pending.peek().map_or(false, |block| block.start_pos == pos);
		if block_starts_here {
			if let Some(block) = pending.next() {
				rewritten.push(I32Const(block.cost as i32));
				rewritten.push(Call(gas_func));
			}
		}

		// The original instruction always follows any metering call inserted for it.
		rewritten.push(instr);
	}

	match pending.next() {
		Some(_) => Err(()),
		None => Ok(()),
	}
}
pub fn inject_gas_counter(module: elements::Module, rules: &rules::Set, gas_module_name: &str)
	-> Result<elements::Module, elements::Module>
{
	
	let mut mbuilder = builder::from_module(module);
	let import_sig = mbuilder.push_signature(
		builder::signature()
			.param().i32()
			.build_sig()
		);
	mbuilder.push_import(
		builder::import()
			.module(gas_module_name)
			.field("gas")
			.external().func(import_sig)
			.build()
		);
	
	let mut module = mbuilder.build();
	
	
	let gas_func = module.import_count(elements::ImportCountType::Function) as u32 - 1;
	let total_func = module.functions_space() as u32;
	let mut need_grow_counter = false;
	let mut error = false;
	
	for section in module.sections_mut() {
		match section {
			&mut elements::Section::Code(ref mut code_section) => {
				for ref mut func_body in code_section.bodies_mut() {
					update_call_index(func_body.code_mut(), gas_func);
					if let Err(_) = inject_counter(func_body.code_mut(), rules, gas_func) {
						error = true;
						break;
					}
					if rules.grow_cost() > 0 {
						if inject_grow_counter(func_body.code_mut(), total_func) > 0 {
							need_grow_counter = true;
						}
					}
				}
			},
			&mut elements::Section::Export(ref mut export_section) => {
				for ref mut export in export_section.entries_mut() {
					if let &mut elements::Internal::Function(ref mut func_index) = export.internal_mut() {
						if *func_index >= gas_func { *func_index += 1}
					}
				}
			},
			&mut elements::Section::Element(ref mut elements_section) => {
				
				
				for ref mut segment in elements_section.entries_mut() {
					
					for func_index in segment.members_mut() {
						if *func_index >= gas_func { *func_index += 1}
					}
				}
			},
			&mut elements::Section::Start(ref mut start_idx) => {
				if *start_idx >= gas_func { *start_idx += 1}
			},
			_ => { }
		}
	}
	if error { return Err(module); }
	if need_grow_counter { Ok(add_grow_counter(module, rules, gas_func)) } else { Ok(module) }
}
// Tests for gas counter injection: hand-built modules for the memory-grow and call re-indexing
// paths, and WAT-based before/after fixtures for the metered block placement.
#[cfg(test)]
mod tests {
	extern crate wabt;
	use parity_wasm::{serialize, builder, elements};
	use parity_wasm::elements::Instruction::*;
	use super::*;
	use rules;
	// Returns the instructions of the function body at `index` in the code section, or `None`
	// if the module has no code section or no body at that index.
	pub fn get_function_body(module: &elements::Module, index: usize)
		-> Option<&[elements::Instruction]>
	{
		module.code_section()
			.and_then(|code_section| code_section.bodies().get(index))
			.map(|func_body| func_body.code().elements())
	}
	// With a non-zero grow cost, `grow_memory` is replaced by a call to the appended helper
	// (function index 2), and the helper body charges gas before growing.
	#[test]
	fn simple_grow() {
		let module = builder::module()
			.global()
				.value_type().i32()
				.build()
			.function()
				.signature().param().i32().build()
				.body()
					.with_instructions(elements::Instructions::new(
						vec![
							GetGlobal(0),
							GrowMemory(0),
							End
						]
					))
					.build()
				.build()
			.build();
		let injected_module = inject_gas_counter(
			module,
			&rules::Set::default().with_grow_cost(10000),
			"env",
		).unwrap();
		assert_eq!(
			get_function_body(&injected_module, 0).unwrap(),
			&vec![
				I32Const(2),
				Call(0),
				GetGlobal(0),
				Call(2),
				End
			][..]
		);
		assert_eq!(
			get_function_body(&injected_module, 1).unwrap(),
			&vec![
				GetLocal(0),
				GetLocal(0),
				I32Const(10000),
				I32Mul,
				Call(0),
				GrowMemory(0),
				End,
			][..]
		);
		// The instrumented module must still serialize and be readable by wabt.
		let binary = serialize(injected_module).expect("serialization failed");
		self::wabt::wasm2wat(&binary).unwrap();
	}
	// With the default rules, `grow_memory` is left in place and no helper function is added.
	#[test]
	fn grow_no_gas_no_track() {
		let module = builder::module()
			.global()
				.value_type().i32()
				.build()
			.function()
				.signature().param().i32().build()
				.body()
					.with_instructions(elements::Instructions::new(
						vec![
							GetGlobal(0),
							GrowMemory(0),
							End
						]
					))
					.build()
				.build()
			.build();
		let injected_module = inject_gas_counter(module, &rules::Set::default(), "env").unwrap();
		assert_eq!(
			get_function_body(&injected_module, 0).unwrap(),
			&vec![
				I32Const(2),
				Call(0),
				GetGlobal(0),
				GrowMemory(0),
				End
			][..]
		);
		// Only the gas import and the original function exist.
		assert_eq!(injected_module.functions_space(), 2);
		let binary = serialize(injected_module).expect("serialization failed");
		self::wabt::wasm2wat(&binary).unwrap();
	}
	// Existing `call 0` instructions are re-targeted to index 1 once the gas import takes
	// index 0, and each `if`/`else` arm gets its own metering call.
	#[test]
	fn call_index() {
		let module = builder::module()
			.global()
				.value_type().i32()
				.build()
			.function()
				.signature().param().i32().build()
				.body().build()
				.build()
			.function()
				.signature().param().i32().build()
				.body()
					.with_instructions(elements::Instructions::new(
						vec![
							Call(0),
							If(elements::BlockType::NoResult),
								Call(0),
								Call(0),
								Call(0),
							Else,
								Call(0),
								Call(0),
							End,
							Call(0),
							End
						]
					))
					.build()
				.build()
			.build();
		let injected_module = inject_gas_counter(module, &Default::default(), "env").unwrap();
		assert_eq!(
			get_function_body(&injected_module, 1).unwrap(),
			&vec![
				I32Const(3),
				Call(0),
				Call(1),
				If(elements::BlockType::NoResult),
					I32Const(3),
					Call(0),
					Call(1),
					Call(1),
					Call(1),
				Else,
					I32Const(2),
					Call(0),
					Call(1),
					Call(1),
				End,
				Call(1),
				End
			][..]
		);
	}
	// Injection must fail when the rules forbid floats and the body contains a float constant.
	#[test]
	fn forbidden() {
		let module = builder::module()
			.global()
				.value_type().i32()
				.build()
			.function()
				.signature().param().i32().build()
				.body()
					.with_instructions(elements::Instructions::new(
						vec![
							F32Const(555555),
							End
						]
					))
					.build()
				.build()
			.build();
		let rules = rules::Set::default().with_forbidden_floats();
		if let Err(_) = inject_gas_counter(module, &rules, "env") { }
		else { panic!("Should be error because of the forbidden operation")}
	}
	// Converts WAT text to a module; wabt validation is switched off for the conversion.
	fn parse_wat(source: &str) -> elements::Module {
		let module_bytes = wabt::Wat2Wasm::new()
			.validate(false)
			.convert(source)
			.expect("failed to parse module");
		elements::deserialize_buffer(module_bytes.as_ref())
			.expect("failed to parse module")
	}
	// Generates a test that instruments `input` with the default rules and the "env" gas
	// module, then compares function body 0 of the result with function body 0 of `expected`.
	macro_rules! test_gas_counter_injection {
		(name = $name:ident; input = $input:expr; expected = $expected:expr) => {
			#[test]
			fn $name() {
				let input_module = parse_wat($input);
				let expected_module = parse_wat($expected);
				let injected_module = inject_gas_counter(input_module, &Default::default(), "env")
					.expect("inject_gas_counter call failed");
				let actual_func_body = get_function_body(&injected_module, 0)
					.expect("injected module must have a function body");
				let expected_func_body = get_function_body(&expected_module, 0)
					.expect("post-module must have a function body");
				assert_eq!(actual_func_body, expected_func_body);
			}
		}
	}
	// A straight-line body gets a single metering call at its start.
	test_gas_counter_injection! {
		name = simple;
		input = r#"
		(module
			(func (result i32)
				(get_global 0)))
		"#;
		expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 1))
				(get_global 0)))
		"#
	}
	// A nested `block` without branches is merged into the enclosing metered block.
	test_gas_counter_injection! {
		name = nested;
		input = r#"
		(module
			(func (result i32)
				(get_global 0)
				(block
					(get_global 0)
					(get_global 0)
					(get_global 0))
				(get_global 0)))
		"#;
		expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 6))
				(get_global 0)
				(block
					(get_global 0)
					(get_global 0)
					(get_global 0))
				(get_global 0)))
		"#
	}
	// The `then` and `else` arms are metered separately from the surrounding code.
	test_gas_counter_injection! {
		name = ifelse;
		input = r#"
		(module
			(func (result i32)
				(get_global 0)
				(if
					(then
						(get_global 0)
						(get_global 0)
						(get_global 0))
					(else
						(get_global 0)
						(get_global 0)))
				(get_global 0)))
		"#;
		expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 3))
				(get_global 0)
				(if
					(then
						(call 0 (i32.const 3))
						(get_global 0)
						(get_global 0)
						(get_global 0))
					(else
						(call 0 (i32.const 2))
						(get_global 0)
						(get_global 0)))
				(get_global 0)))
		"#
	}
	// A `br 0` ends the metered block; the code after it inside the block is metered on its own.
	test_gas_counter_injection! {
		name = branch_innermost;
		input = r#"
		(module
			(func (result i32)
				(get_global 0)
				(block
					(get_global 0)
					(drop)
					(br 0)
					(get_global 0)
					(drop))
				(get_global 0)))
		"#;
		expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 6))
				(get_global 0)
				(block
					(get_global 0)
					(drop)
					(br 0)
					(call 0 (i32.const 2))
					(get_global 0)
					(drop))
				(get_global 0)))
		"#
	}
	// A `br_if 1` out of an `if` splits the enclosing block's metering after the `if`.
	test_gas_counter_injection! {
		name = branch_outer_block;
		input = r#"
		(module
			(func (result i32)
				(get_global 0)
				(block
					(get_global 0)
					(if
						(then
							(get_global 0)
							(get_global 0)
							(drop)
							(br_if 1)))
					(get_global 0)
					(drop))
				(get_global 0)))
		"#;
		expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 5))
				(get_global 0)
				(block
					(get_global 0)
					(if
						(then
							(call 0 (i32.const 4))
							(get_global 0)
							(get_global 0)
							(drop)
							(br_if 1)))
					(call 0 (i32.const 2))
					(get_global 0)
					(drop))
				(get_global 0)))
		"#
	}
	// Branches that target the enclosing `loop` do not split the loop body's metered block.
	test_gas_counter_injection! {
		name = branch_outer_loop;
		input = r#"
		(module
			(func (result i32)
				(get_global 0)
				(loop
					(get_global 0)
					(if
						(then
							(get_global 0)
							(br_if 0))
						(else
							(get_global 0)
							(get_global 0)
							(drop)
							(br_if 1)))
					(get_global 0)
					(drop))
				(get_global 0)))
		"#;
		expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 3))
				(get_global 0)
				(loop
					(call 0 (i32.const 4))
					(get_global 0)
					(if
						(then
							(call 0 (i32.const 2))
							(get_global 0)
							(br_if 0))
						(else
							(call 0 (i32.const 4))
							(get_global 0)
							(get_global 0)
							(drop)
							(br_if 1)))
					(get_global 0)
					(drop))
				(get_global 0)))
		"#
	}
	// A `return` inside an `if` forces the code after the `if` into its own metered block.
	test_gas_counter_injection! {
		name = return_from_func;
		input = r#"
		(module
			(func (result i32)
				(get_global 0)
				(if
					(then
						(return)))
				(get_global 0)))
		"#;
		expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 2))
				(get_global 0)
				(if
					(then
						(call 0 (i32.const 1))
						(return)))
				(call 0 (i32.const 1))
				(get_global 0)))
		"#
	}
	// Branches from both arms of an `if`: `br 1` in `then` and `br 0` in `else`.
	test_gas_counter_injection! {
		name = branch_from_if_not_else;
		input = r#"
		(module
			(func (result i32)
				(get_global 0)
				(block
					(get_global 0)
					(if
						(then (br 1))
						(else (br 0)))
					(get_global 0)
					(drop))
				(get_global 0)))
		"#;
		expected = r#"
		(module
			(func (result i32)
				(call 0 (i32.const 5))
				(get_global 0)
				(block
					(get_global 0)
					(if
						(then
							(call 0 (i32.const 1))
							(br 1))
						(else
							(call 0 (i32.const 1))
							(br 0)))
					(call 0 (i32.const 2))
					(get_global 0)
					(drop))
				(get_global 0)))
		"#
	}
}