local highlighter = require "vim.treesitter.highlighter"
local parsers = require "nvim-treesitter.parsers"
local ts = vim.treesitter

--- Per-language override for the node type that `highlight-assertions`
--- scans (its `-c` option) when looking for assertion comments.
--- Languages missing from this table fall back to "comment".
local COMMENT_NODES = {
  ["haskell"] = "haddock",
  ["markdown"] = "html_block",
}

--- Build the failure message reported when a highlight assertion does not hold.
-- @tparam string file path of the test file
-- @tparam number row 0-based row of the assertion
-- @tparam number col 0-based column of the assertion
-- @tparam string expected capture name from the assertion (may start with "!")
-- @tparam table captures set of capture names found at (row, col)
-- @treturn string
local function format_failure(file, row, col, expected, captures)
  return string.format(
    'Error at %s:%d:%d: expected "%s", captures: %s',
    file,
    row + 1,
    col + 1,
    expected,
    vim.inspect(vim.tbl_keys(captures))
  )
end

--- Run the external `highlight-assertions` tool on `file` and decode its JSON output.
-- The command is passed to `system()` as a List so it is executed directly,
-- without a shell; file names containing quotes or spaces are passed intact.
-- @tparam string file path of the test file
-- @tparam string lang parser language of the buffer
-- @treturn table decoded list of assertions
local function read_assertions(file, lang)
  local parser_path = vim.api.nvim_get_runtime_file("parser/" .. lang .. ".so", false)[1]
  assert.Truthy(parser_path, 'No parser for language "' .. lang .. '" found on runtimepath')

  local output = vim.fn.system {
    "highlight-assertions",
    "-p",
    parser_path,
    "-s",
    file,
    "-c",
    COMMENT_NODES[lang] or "comment",
  }
  assert.same(0, vim.v.shell_error, '"highlight-assertions" failed for ' .. file .. ": " .. output)
  return vim.fn.json_decode(output)
end

--- Collect the names of all highlight captures covering position (row, col).
-- Walks every tree the highlighter knows about (including injections) and
-- records capture names whose node contains the position. The "spell" and
-- "conceal" captures are ignored.
-- @param hl highlighter object created by `highlighter.new`
-- @tparam number row 0-based row
-- @tparam number col 0-based column
-- @treturn table set of capture names (name -> true)
local function captures_at(hl, row, col)
  local found = {}
  hl:prepare_highlight_states(row, row + 1)
  hl:for_each_highlight_state(function(state)
    if not state.tstree then
      return
    end

    local root = state.tstree:root()
    local root_start_row, _, root_end_row, _ = root:range()

    -- Only worry about trees within the line range
    if root_start_row > row or root_end_row < row then
      return
    end

    local query = state.highlighter_query
    local ts_query = query:query()

    -- Some injected languages may not have highlight queries.
    if not ts_query then
      return
    end

    for capture, node, _ in ts_query:iter_captures(root, hl.bufnr, row, row + 1) do
      local hl_id = query:get_hl_from_capture(capture)
      assert.is.truthy(hl_id)
      assert.Truthy(node)

      if hl_id and ts.is_in_node_range(node, row, col) then
        local name = ts_query.captures[capture] -- name of the capture in the query
        if name ~= nil and name ~= "spell" and name ~= "conceal" then
          found[name] = true
        end
      end
    end
  end, true)
  return found
end

--- Check every highlight assertion embedded in a test file.
-- Loads `file` into a buffer, asks `highlight-assertions` for the expected
-- captures, and verifies each one against the tree-sitter highlighter.
-- An expected name prefixed with "!" asserts that the capture is absent.
-- @tparam string file path of the test file
local function check_assertions(file)
  local buf = vim.fn.bufadd(file)
  vim.fn.bufload(file)
  local lang = parsers.get_buf_lang(buf)
  assert.same(
    1,
    vim.fn.executable "highlight-assertions",
    '"highlight-assertions" not executable!'
      .. ' Get it via "cargo install --git https://github.com/theHamsta/highlight-assertions"'
  )
  assert.Truthy(lang, "Could not determine parser language for " .. file)

  local assertions = read_assertions(file, lang)

  local parser = parsers.get_parser(buf, lang)
  parser:parse(true)
  local hl = highlighter.new(parser, {})

  assert.True(#assertions > 0, "No assertions detected!")
  for _, assertion in ipairs(assertions) do
    local row = assertion.position.row
    local col = assertion.position.column
    assert.is.number(row)
    assert.is.number(col)

    local captures = captures_at(hl, row, col)
    local expected = assertion.expected_capture_name
    local message = format_failure(file, row, col, expected, captures)

    if expected:match "^!" then
      assert.Falsy(captures[expected:sub(2)], message)
    else
      assert.True(captures[expected], message)
    end
  end
end

--- Register one test per highlight fixture under tests/query/highlights.
describe("highlight queries", function()
  -- glob() with {list} = true returns a List, so paths containing whitespace
  -- stay intact (splitting the newline-joined string would break them apart).
  local files = vim.fn.glob("tests/query/highlights/**/*.*", false, true)
  for _, file in ipairs(files) do
    it(file, function()
      check_assertions(file)
    end)
  end
end)
