C++20 KMP

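KMP 算法的最大特点是指示主串的指针不需回溯:整个匹配过程中只需从头至尾扫描主串一遍。这对处理从外设输入的庞大文件非常有效——可以边读边匹配,而无需回头重读。然而,教材和互联网上的大多数例子都是在内存中的字符串上查找。这里提供一个可以直接在文件(输入流)上查找的实现:主串通过 `std::istreambuf_iterator` 逐字符读取、只读一遍,只有模式串需要保存在内存中。同一个 `kmp_searcher` 也可以作为搜索器传给 `std::search`,用于查找内存中的序列。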

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <vector>

template <std::forward_iterator PatIt, class SizeT = size_t,
          class NextT = std::vector<SizeT>>
class kmp_searcher {
 public:
  using size_type = SizeT;
  using next_table = NextT;
  static constexpr size_type npos = -1;

 public:
  kmp_searcher(PatIt first, PatIt last)
      : pat_first_{first},
        pat_last_{last},
        next_(std::distance(first, last), npos) {
    init_next_table();
  }

  template <std::forward_iterator FowardIt>
  std::pair<FowardIt, FowardIt> operator()(FowardIt first,
                                           FowardIt last) const {
    auto pos = do_search(first, last);
    if (pos == npos) {
      return std::make_pair(last, last);
    }
    return std::make_pair(std::next(first, pos), last);
  }

  template <std::input_iterator InputIt>
  size_type find(InputIt first, InputIt last) const {
    return do_search(first, last);
  }

  size_type size() const noexcept { return next_.size(); }

 private:
  void init_next_table() noexcept {
    auto it = pat_first_;
    auto j = npos;
    assert(pat_first_ != pat_last_);
    while (std::next(it) != pat_last_) {
      if (j == npos) {
        j = 0;
      } else if (*it == *std::next(pat_first_, j)) {
        ++j;
      } else {
        j = next_[j];
        continue;
      }
      std::advance(it, 1);
      next_[std::distance(pat_first_, it)] =
          (*it == *std::next(pat_first_, j) ? next_[j] : j);
    }
  }

  template <std::input_iterator InputIt>
  size_type do_search(InputIt first, InputIt last) const {
    size_type i = 0;
    size_type j = 0;
    assert(first != last);
    while (first != last && j != next_.size()) {
      if (j == npos) {
        ++first;
        ++i;
        j = 0;
      } else if (*first == *std::next(pat_first_, j)) {
        ++first;
        ++i;
        ++j;
      } else {
        j = next_[j];
      }
    }
    if (j == next_.size()) {
      return i - j;
    }
    return npos;
  }

 private:
  PatIt pat_first_;
  PatIt pat_last_;
  next_table next_;
};

/// One-shot search: builds a kmp_searcher for [pat_first, pat_last) and
/// returns the offset of the pattern's first occurrence in [first, last),
/// or the searcher's npos (-1 converted to SizeT) when it does not occur.
template <class SizeT = size_t, std::input_iterator InputIt,
          std::forward_iterator PatIt>
SizeT kmp_find(InputIt first, InputIt last, PatIt pat_first, PatIt pat_last) {
  const kmp_searcher<PatIt, SizeT> searcher{pat_first, pat_last};
  return searcher.find(first, last);
}

/// Finds the first occurrence of `p` in `istrm`, reading the stream once.
///
/// The first `pos` characters after the stream's current read position are
/// consumed and skipped before searching. On return the stream has been read
/// through the end of the match, or to its end when there is no match.
///
/// @return offset of the match counted from the stream's read position at
///         the time of the call, or -1 when `pos` is negative, the stream
///         holds at most `pos` characters, or `p` does not occur.
std::streamsize kmp_find(std::istream& istrm, std::string_view p,
                         std::streamsize pos = 0) {
  constexpr auto not_found = static_cast<std::streamsize>(-1);
  if (pos < 0) {
    return not_found;  // std::next cannot move an input iterator backwards
  }
  std::istreambuf_iterator<std::istream::char_type> first{istrm}, last;
  // Advance `first` itself rather than relying on a discarded copy sharing
  // the stream buffer's read position.
  first = std::next(first, pos);
  if (first == last) {
    return not_found;
  }
  const auto result =
      kmp_find<std::streamsize>(first, last, p.cbegin(), p.cend());
  return result == not_found ? not_found : result + pos;
}

/// Counts non-overlapping occurrences of [pat_first, pat_last) in
/// [first, last) with a single pass over the text: after each match the next
/// search resumes right after the matched elements. An empty pattern yields 0.
///
/// Each search advances this function's own `first` (through a small
/// forwarding adaptor) instead of a copy of it. This makes the count correct
/// for every single-pass iterator, including ones whose copies cache the last
/// value read (e.g. std::istream_iterator), not only std::istreambuf_iterator.
template <class SizeT = size_t, std::input_iterator InputIt,
          std::forward_iterator PatIt>
SizeT kmp_count(InputIt first, InputIt last, PatIt pat_first, PatIt pat_last) {
  using searcher = kmp_searcher<PatIt, SizeT>;
  if (pat_first == pat_last) {
    return 0;  // an empty pattern would "match" without consuming anything
  }

  // Input iterator that forwards every operation to another iterator, so
  // searcher::find advances `first` itself rather than a private copy.
  struct shared_cursor {
    using value_type = std::iter_value_t<InputIt>;
    using difference_type = std::iter_difference_t<InputIt>;
    using iterator_concept = std::input_iterator_tag;

    InputIt* it = nullptr;

    decltype(auto) operator*() const { return **it; }
    shared_cursor& operator++() {
      ++*it;
      return *this;
    }
    void operator++(int) { ++*it; }
    bool operator==(const shared_cursor& other) const {
      return *it == *other.it;
    }
  };

  const searcher kmp(pat_first, pat_last);
  SizeT cnt = 0;
  // find() leaves `first` just past the match, or at `last` when nothing
  // more is found.
  while (first != last &&
         kmp.find(shared_cursor{&first}, shared_cursor{&last}) !=
             searcher::npos) {
    ++cnt;
  }
  return cnt;
}

/// Counts non-overlapping occurrences of `p` in `istrm`, after skipping the
/// first `pos` characters following the stream's current read position.
///
/// @return the number of occurrences, or -1 when `pos` is negative or the
///         stream holds at most `pos` characters (mirroring the stream
///         overload of kmp_find).
std::streamsize kmp_count(std::istream& istrm, std::string_view p,
                          std::streamsize pos = 0) {
  constexpr auto invalid = static_cast<std::streamsize>(-1);
  if (pos < 0) {
    return invalid;  // std::next cannot move an input iterator backwards
  }
  std::istreambuf_iterator<std::istream::char_type> first{istrm}, last;
  // Skip on `first` itself instead of on a discarded copy.
  first = std::next(first, pos);
  if (first == last) {
    return invalid;
  }
  return kmp_count<std::streamsize>(first, last, p.cbegin(), p.cend());
}

/// Demo driver: an in-memory search through std::search, then a single-pass
/// find and count over a text file.
///
/// An optional argv[1] overrides the input file path; without it, the
/// original hard-coded path is used.
/// @return 0 on success, 1 when the input file cannot be opened.
int main(int argc, char* argv[]) {
  std::string_view p = "abc";
  kmp_searcher searcher(p.cbegin(), p.cend());

  // kmp_searcher is passed to std::search as a searcher object.
  std::string_view s = "123abcabc";
  auto it = std::search(s.cbegin(), s.cend(), searcher);
  std::cout << "std::search(s: " << s << ", p: " << p
            << "): " << std::string(it, s.cend()) << '\n';

  // test.txt: 123abc456abcabc789abc
  const std::string_view fn =
      argc > 1 ? std::string_view{argv[1]}
               : std::string_view{"/Users/admin/dev/datastructure/test.txt"};
  const std::streamsize pos = 13;  // the stream overloads take std::streamsize
  // Each search consumes the stream, so the file is opened afresh each time.
  // It is opened from a std::string: that is portable, whereas a
  // std::string_view only works where the library accepts it through an
  // implicit conversion to std::filesystem::path.
  {
    std::ifstream infile{std::string{fn}};
    if (!infile) {
      std::cerr << "cannot open " << fn << '\n';
      return 1;
    }
    std::cout << "kmp_find(fn: " << fn << ", p: " << p << ", pos: " << pos
              << "): " << kmp_find(infile, p, pos) << '\n';
  }
  {
    std::ifstream infile{std::string{fn}};
    if (!infile) {
      std::cerr << "cannot open " << fn << '\n';
      return 1;
    }
    std::cout << "kmp_count(fn: " << fn << ", p:" << p << ", pos: " << pos
              << "): " << kmp_count(infile, p, pos) << '\n';
  }
}

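输出(以 `main` 注释中的 test.txt 内容 `123abc456abcabc789abc` 为准):

std::search(s: 123abcabc, p: abc): abcabc
kmp_find(fn: /Users/admin/dev/datastructure/test.txt, p: abc, pos: 13): 18
kmp_count(fn: /Users/admin/dev/datastructure/test.txt, p:abc, pos: 13): 1

说明:`pos = 13` 先跳过前 13 个字符 `123abca456abca` 中的 `123abc456abca`,剩余部分为 `bc789abc`。其中只有一处 `abc`,起始于绝对位置 18;按不重叠计数也只有 1 处。原文此处给出的 19 和 3 与上述文件内容不符。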