
local bit = lazyRequire ('bit')
local Match = plugins.style:lazyRequire('selectors.matching.Match')
local Counters = plugins.style:lazyRequire('profiling.Counters')
local RuleApplicationStages = plugins.style:lazyRequire('rules.RuleApplicationStages')
local CandidateRuleGatherer = plugins.style:lazyRequire('selectors.matching.CandidateRuleGatherer')

-- Constructor. Keeps the injected collaborators and creates a private
-- CandidateRuleGatherer used to pre-filter rules before exact matching.
--   matchabilityCache: queried/filled by getMatchingRules
--                      (getMatchesForMatchabilityKey / setMatchesForMatchabilityKey)
--   mutationTracker:   receives flagProgressForNode notifications
-- NOTE(review): `class` is a project helper; presumably it exposes this as
-- MatchingEngine.new — confirm.
local MatchingEngine = class(function(self, matchabilityCache, mutationTracker)
	-- Collaborators supplied by the caller.
	self._mutationTracker = mutationTracker
	self._cache = matchabilityCache

	-- Rule pre-filter owned by this engine instance.
	self._candidateRuleGatherer = CandidateRuleGatherer.new()
end)

-- Returns the list of Match objects (rule + the specific selector expression
-- that matched) for `node` against `ruleSet`. The expression is kept because
-- the order in which rules are applied depends on the specificity of the
-- expression that matched.
--
-- Results are memoised in the matchability cache, keyed by the node's
-- partition hash and matchability key: if that pair has been resolved before,
-- whatever the cache holds for it is returned without running the matcher.
--
-- Progress is reported to the mutation tracker before the cache lookup
-- (MATCHABILITY_CACHE_CHECK) and, on a miss, before full matching
-- (MATCHING_ALGORITHM).
function MatchingEngine:getMatchingRules(node, ruleSet)
	-- Drop per-node arrays left over from a previous call so they are
	-- re-read from this node.
	self:resetMatchabilityInfo()

	local matchabilityKey = node:getMatchabilityKey()
	local partitionHash = node:getPartitionAsHash()
	local cached = self._cache:getMatchesForMatchabilityKey(partitionHash, matchabilityKey)

	self:_flagProgress(node, RuleApplicationStages.MATCHABILITY_CACHE_CHECK)

	-- Cache hit: an identical matchability was already resolved.
	if cached ~= nil then
		Counters.MATCHABILITY_CACHE_HITS:increment()
		return cached
	end

	Counters.MATCHABILITY_CACHE_MISSES:increment()
	self:_flagProgress(node, RuleApplicationStages.MATCHING_ALGORITHM)

	-- Cheap candidate pre-filter first, then the exact (more expensive) check.
	local exactMatches = self:_getExactMatches(
		node, self._candidateRuleGatherer:getCandidateRules(node, ruleSet))

	self._cache:setMatchesForMatchabilityKey(partitionHash, matchabilityKey, exactMatches)
	return exactMatches
end

-- Reports to the mutation tracker given at construction that `node` has
-- reached `stage` (a RuleApplicationStages value at the call sites).
function MatchingEngine:_flagProgress(node, stage)
	local tracker = self._mutationTracker
	tracker:flagProgressForNode(node, stage)
end

-- Forgets the node and native arrays cached by _getFastMatchabilityInfoFromNode.
-- The next call to that method then re-fetches the arrays, even for the same
-- node. The array cursors (_currentNodeArrayIndex / _currentTagArrayBase) are
-- left untouched.
function MatchingEngine:resetMatchabilityInfo()
	-- Multiple assignment with a single nil: all five fields become nil.
	self._node, self._tagArray, self._nodeTypeArray,
		self._tagCountArray, self._bloomArray = nil
end

-- Runs the exact selector check over `candidates` (rules already pre-filtered
-- by CandidateRuleGatherer). Every expression of every candidate's selector is
-- tested against `node`. A Match is appended for each expression that matches,
-- in candidate order and then expression order, so a rule with several
-- matching expressions yields several Matches.
function MatchingEngine:_getExactMatches(node, candidates)
	local matches = {}

	for ruleIndex = 1, #candidates do
		local rule = candidates[ruleIndex]
		local expressions = rule.selector.expressions

		for exprIndex = 1, #expressions do
			Counters.TOTAL_SELECTORS:increment()

			local expression = expressions[exprIndex]
			if self:nodeMatchesExpression(node, expression) then
				matches[#matches + 1] = Match.new(rule, expression)
			end
		end
	end

	Counters.TOTAL_MATCHES:increment(#matches)
	return matches
end

-- Returns true when `node` satisfies the selector expression `expr`.
-- Segments are consumed right-to-left. The last segment is checked with a nil
-- combinator, which allows only the entry at index 0 of the node's matchability
-- arrays. Each earlier segment is then searched for further along the arrays,
-- presumably the node's ancestors, nearest first (TODO confirm). The search is
-- bounded by the combinationMode of the segment matched just before it.
-- An expression with no segments matches trivially.
function MatchingEngine:nodeMatchesExpression(node, expr)
	self:_getFastMatchabilityInfoFromNode(node)

	local segments = expr.segments
	local segmentCount = #segments

	-- Cursors into the native arrays, which are 0-based (they come from C++).
	self._currentNodeArrayIndex = 0
	self._currentTagArrayBase = 0

	local precedingMode = nil
	for segmentIndex = segmentCount, 1, -1 do
		local segment = segments[segmentIndex]
		if not self:_nodeMatchesSegment(segment, precedingMode) then
			return false
		end
		precedingMode = segment.combinationMode
	end

	return true
end

-- Loads the native SceneNode's C-based integer arrays into fields on self:
-- tag hashes, node-type hashes, per-entry tag counts, and bloom masks. They
-- are used by the fast tag matching below. The arrays are only re-fetched
-- when `node` differs from the node they were last fetched for.
-- resetMatchabilityInfo forces a re-fetch.
function MatchingEngine:_getFastMatchabilityInfoFromNode(node)
	if node == self._node then
		return -- already loaded for this node
	end

	self._node = node
	self._tagArray = node:getMatchabilityTagList()
	self._nodeTypeArray = node:getMatchabilityNodeTypeList()
	self._tagCountArray = node:getMatchabilityTagCountList()
	self._bloomArray = node:getMatchabilityTagBloomFilterList()
end

-- Searches the matchability arrays, starting at the current cursor, for an
-- entry that satisfies `segment`. Each entry gets a bloom pre-check, then a
-- check for all tags, then a node-type check. `combinationMode` bounds how
-- many entries may be tried:
--   nil or 'child' -> 1
--   'descendant'   -> unbounded
-- The search also stops at the end-of-list sentinel (a tag count of 0xffffffff).
-- Every entry that passes the bloom check advances the cursor, so after a
-- successful match the cursor sits just past the matched entry.
-- NOTE(review): the search stops at the first bloom-filter miss instead of
-- moving on to the next entry. For 'descendant' this is only correct if each
-- entry's bloom mask also covers the entries after it — confirm.
function MatchingEngine:_nodeMatchesSegment(segment, combinationMode)
	local END_OF_LIST = 0xffffffff
	local maxDepth = self:_getDepthForCombinationMode(combinationMode)
	local entriesTried = 0

	while entriesTried < maxDepth do
		-- Ran off the end of the arrays: nothing left to match against.
		if self._tagCountArray[self._currentNodeArrayIndex] == END_OF_LIST then
			return false
		end

		-- The bloom filter proves the segment's tags are absent: give up.
		if self:_canEliminateMatchUsingBloomFilter(segment) then
			return false
		end

		local isMatch = self:_tagArrayContainsAllTags(segment.tags)
			and self:_nodeTypeIsCorrect(segment)

		entriesTried = entriesTried + 1
		self:_incrementFastArrayIndex()

		if isMatch then
			return true
		end
	end

	return false
end

-- Maps the combinator linking two selector segments to the maximum number of
-- array entries that may be tried when searching for a segment:
--   nil          -> 1
--   'child'      -> 1
--   'descendant' -> math.huge (no limit)
-- Any other value raises an error. The value goes through tostring() so that
-- a non-string mode (boolean, table, ...) still produces this diagnostic
-- instead of Lua's generic "attempt to concatenate" error.
function MatchingEngine:_getDepthForCombinationMode(combinationMode)
	if combinationMode == nil or combinationMode == 'child' then
		return 1
	end

	if combinationMode == 'descendant' then
		return math.huge
	end

	error('Unrecognized combination mode: ' .. tostring(combinationMode))
end

-- Uses the pregenerated bloom filter bitmask of the current array entry and
-- the segment's own bitmask to rule out a match cheaply. Returns true (and
-- bumps SELECTORS_ELIMINATED_BY_BLOOM_FILTER) when the segment's mask has at
-- least one bit the entry's mask lacks, so the segment's tags cannot all be
-- present.
--
-- The subset test is written as `(~nodeBloom & segmentBloom) ~= 0` instead of
-- `(nodeBloom & segmentBloom) ~= segmentBloom`. LuaJIT's bit.* functions
-- return *signed* 32-bit results, so comparing a bit.band() result directly
-- with a mask held as an unsigned number (>= 2^31) never compares equal, and
-- such segments would always be wrongly eliminated. A comparison against zero
-- does not depend on the sign representation.
function MatchingEngine:_canEliminateMatchUsingBloomFilter(segment)
	local nodeBloom = self._bloomArray[self._currentNodeArrayIndex]
	local missingBits = bit.band(bit.bnot(nodeBloom), segment.bloom)
	local result = missingBits ~= 0

	if result then
		Counters.SELECTORS_ELIMINATED_BY_BLOOM_FILTER:increment()
	end

	return result
end

-- Returns true when every tag in `tags` (compared by its `hash` field) is
-- among the tag hashes of the current array entry. Those are the `count`
-- hashes starting at _currentTagArrayBase, where `count` is the entry's value
-- in the tag-count array. Bumps SELECTORS_ELIMINATED_BY_TAG_COUNT when the
-- entry has fewer tags than requested, and SELECTORS_ELIMINATED_BY_TAG_MISMATCH
-- when a requested tag is absent.
-- NOTE(review): the count short-circuit assumes `tags` holds no duplicate
-- hashes. A segment repeating the same tag would be rejected — confirm the
-- selector parser de-duplicates.
function MatchingEngine:_tagArrayContainsAllTags(tags)
	local tagCount = self._tagCountArray[self._currentNodeArrayIndex]
	local requiredCount = #tags

	if tagCount < requiredCount then
		Counters.SELECTORS_ELIMINATED_BY_TAG_COUNT:increment()
		return false
	end

	-- Restrict the search to the slice of the tag list owned by this entry.
	local tagArray = self._tagArray
	local first = self._currentTagArrayBase
	local last = first + tagCount - 1

	for t = 1, requiredCount do
		local wanted = tags[t].hash
		local k = first
		while k <= last and tagArray[k] ~= wanted do
			k = k + 1
		end
		if k > last then
			Counters.SELECTORS_ELIMINATED_BY_TAG_MISMATCH:increment()
			return false
		end
	end

	return true
end

-- Returns true when the segment accepts any node type ("*") or its
-- nodeTypeHash equals the current array entry's node-type hash. A mismatch
-- bumps SELECTORS_ELIMINATED_BY_TYPE_MISMATCH.
function MatchingEngine:_nodeTypeIsCorrect(segment)
	if segment.nodeType == "*" then
		return true
	end

	local typeMatches =
		segment.nodeTypeHash == self._nodeTypeArray[self._currentNodeArrayIndex]

	if not typeMatches then
		Counters.SELECTORS_ELIMINATED_BY_TYPE_MISMATCH:increment()
	end

	return typeMatches
end

-- Moves the cursors to the next array entry. The tag-list base skips past the
-- current entry's tags (its tag count), and the entry index advances by one.
function MatchingEngine:_incrementFastArrayIndex()
	local nodeIndex = self._currentNodeArrayIndex
	local tagsInEntry = self._tagCountArray[nodeIndex]

	self._currentTagArrayBase = self._currentTagArrayBase + tagsInEntry
	self._currentNodeArrayIndex = nodeIndex + 1
end

-- Module export.
return MatchingEngine
