rapid7/ruby_smb

View on GitHub
lib/ruby_smb/compression/lznt1.rb

Summary

Maintainability
C
1 day
Test Coverage
module RubySMB
  module Compression
    module LZNT1
      def self.compress(buf, chunk_size: 0x1000)
        out = ''
        until buf.empty?
          chunk = buf[0...chunk_size]
          compressed = compress_chunk(chunk)
          # chunk is compressed
          if compressed.length < chunk.length
            out << [0xb000 | (compressed.length - 1)].pack('v')
            out << compressed
          else
            out << [0x3000 | (chunk.length - 1)].pack('v')
            out << chunk
          end
          buf = buf[chunk_size..-1]
          break if buf.nil?
        end

        out
      end

      def self.compress_chunk(chunk)
        blob = chunk
        out = ''
        pow2 = 0x10
        l_mask3 = 0x1002
        o_shift = 12
        until blob.empty?
          bits = 0
          tmp = ''
          i = -1
          loop do
            i += 1
            bits >>= 1
            while pow2 < (chunk.length - blob.length)
              pow2 <<= 1
              l_mask3 = (l_mask3 >> 1) + 1
              o_shift -= 1
            end

            max_len = [blob.length, l_mask3].min
            offset, length = find(chunk[0...(chunk.length - blob.length)], blob, max_len)

            # try to find more compressed pattern
            _offset2, length2 = find(chunk[0...chunk.length - blob.length + 1], blob[1..-1], max_len)

            length = 0 if length < length2

            if length > 0
              symbol = ((offset - 1) << o_shift) | (length - 3)
              tmp << [symbol].pack('v')
              # set the highest bit
              bits |= 0x80
              blob = blob[length..-1]
            else
              tmp += blob[0]
              blob = blob[1..-1]
            end

            break if blob.empty? || i == 7
          end

          out << [bits >> (7 - i)].pack('C')
          out << tmp
        end

        out
      end

      def self.decompress(buf, length_check: true)
        out = ''
        until buf.empty?
          header = buf.unpack1('v')
          length = (header & 0xfff) + 1
          raise EncodingError, 'invalid chunk length' if length_check && length > (buf.length - 2)

          chunk = buf[2...length + 2]
          out << if header & 0x8000 == 0
                   chunk
                 else
                   decompress_chunk(chunk)
                 end
          buf = buf[length + 2..-1]
        end

        out
      end

      def self.decompress_chunk(chunk)
        out = ''
        until chunk.empty?
          flags = chunk[0].unpack1('C')
          chunk = chunk[1..-1]
          8.times do |i|
            if (flags >> i & 1) == 0
              out << chunk[0]
              chunk = chunk[1..-1]
            else
              flag = chunk.unpack1('v')
              pos = out.length - 1
              l_mask = 0xfff
              o_shift = 12
              while pos >= 0x10
                l_mask >>= 1
                o_shift -= 1
                pos >>= 1
              end

              length = (flag & l_mask) + 3
              offset = (flag >> o_shift) + 1

              if length >= offset
                out_offset = out[-offset..-1]
                tmp = out_offset * (0xfff / out_offset.length + 1)
                out << tmp[0...length]
              else
                out << out[-offset..-offset + length - 1]
              end
              chunk = chunk[2..-1]
            end
            break if chunk.empty?
          end
        end

        out
      end

      class << self
        private

        def find(src, target, max_len)
          result_offset = 0
          result_length = 0
          1.upto(max_len - 1) do |i|
            offset = src.rindex(target[0...i])
            next if offset.nil?

            tmp_offset = src.length - offset
            tmp_length = i
            if tmp_offset == tmp_length
              src_offset = src[offset..-1]
              tmp = src_offset * (0xfff / src_offset.length + 1)
              i.upto(max_len) do |j|
                offset = tmp.rindex(target[0...j])
                break if offset.nil?

                tmp_length = j
              end
            end

            if tmp_length > result_length
              result_offset = tmp_offset
              result_length = tmp_length
            end
          end

          result_length < 3 ? [0, 0] : [result_offset, result_length]
        end
      end
    end
  end
end