#!/usr/bin/python
# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import contextlib
import json
import os
import re
import subprocess
import sys
import tempfile
import time
import uuid

# Number of sectors set aside for each copy of the GPT. The layout code
# starts placing partitions at this block (i.e. it is the first sector we can
# use) and reserves the same amount again at the end of the disk for the
# second GPT.
GPT_RESERVED_SECTORS = 34


class ConfigNotFound(Exception):
  """Raised when the partition layout config file does not exist."""
class PartitionNotFound(Exception):
  """Raised when no usable partition matches a requested number or label."""
class InvalidLayout(Exception):
  """Raised when a layout is malformed or incompatible with a disk image."""


def LoadPartitionConfig(options):
  """Loads a partition tables configuration file into a Python object.

  The JSON file is validated, every layout is completed with the values it
  inherits from the 'base' layout, and derived values (sizes in bytes, start
  offsets, UUIDs) are computed for each non-blank partition. The total disk
  size of the requested layout is stored in the metadata as 'blocks' and
  'bytes'.

  Args:
    options: Flags passed to the script
  Returns:
    A tuple of (complete configuration object, selected layout object)
  Raises:
    ConfigNotFound: the layout file does not exist
    InvalidLayout: the file content is malformed or the requested layout is
      not defined in it
  """

  valid_keys = set(('_comment', 'metadata', 'layouts'))
  valid_layout_keys = set((
      '_comment', 'type', 'num', 'label', 'blocks', 'block_size', 'fs_blocks',
      'fs_block_size', 'fs_type', 'features', 'uuid', 'part_alignment', 'mount',
      'binds', 'fs_subvolume', 'fs_bytes_per_inode', 'fs_inode_size', 'fs_label'))
  integer_layout_keys = set((
      'blocks', 'block_size', 'fs_blocks', 'fs_block_size', 'part_alignment',
      'fs_bytes_per_inode', 'fs_inode_size'))
  required_layout_keys = set(('type', 'num', 'label', 'blocks'))

  filename = options.disk_layout_file
  if not os.path.exists(filename):
    raise ConfigNotFound('Partition config %s was not found!' % filename)
  with open(filename) as f:
    config = json.load(f)

  unknown_keys = set(config.keys()) - valid_keys
  if unknown_keys:
    raise InvalidLayout('Unknown items: %s' % ' '.join(unknown_keys))

  try:
    metadata = config['metadata']
    base = config['layouts']['base']
    for key in ('part_alignment', 'disk_alignment',
        'block_size', 'fs_block_size'):
      metadata[key] = int(metadata[key])
  except KeyError as e:
    raise InvalidLayout('Metadata is missing required entries: %s' % e)

  # Fail early with a readable message instead of a bare KeyError on the
  # final lookup when the requested layout is not defined.
  if options.disk_layout not in config['layouts']:
    raise InvalidLayout('Unknown disk layout %r, choose one of: %s' %
                        (options.disk_layout,
                         ' '.join(sorted(config['layouts']))))

  def VerifyLayout(layout_name, layout, base=None):
    """Validates one layout and normalizes its integer and uuid values."""
    for part_num, part in layout.items():
      part['num'] = int(part_num)
      part_keys = set(part)
      unknown_keys = part_keys - valid_layout_keys
      if unknown_keys:
        raise InvalidLayout('Unknown items in partition %s %s: %r' %
                            (layout_name, part_num, ' '.join(unknown_keys)))

      for int_key in integer_layout_keys.intersection(part_keys):
        part[int_key] = int(part[int_key])

      # Keys supplied by the base layout count as present for this partition.
      part_base = base.get(part_num, {}) if base else {}
      part_keys.update(part_base)

      if part.get('type', None) == 'blank':
        continue

      missing_keys = required_layout_keys - part_keys
      if missing_keys:
        raise InvalidLayout('Missing items in partition %s %s: %s' %
                            (layout_name, part_num, ' '.join(missing_keys)))

      if 'uuid' in part:
        try:
          # double check the string formatting
          part['uuid'] = str(uuid.UUID(part['uuid']))
        except ValueError as e:
          raise InvalidLayout('Invalid uuid %r: %s' % (part['uuid'], e))

      if 'fs_type' in part:
        if part['fs_type'] not in ('btrfs', 'ext2', 'ext4', 'vfat'):
          raise InvalidLayout('Invalid fs_type: %r' % part['fs_type'])

      # The filesystem type may be inherited from base, so the option checks
      # below must look at the effective value, not only at the override.
      fs_type = part.get('fs_type', part_base.get('fs_type', None))

      if 'fs_subvolume' in part:
        if fs_type != 'btrfs':
          raise InvalidLayout('Invalid fs: only btrfs supports subvolumes')

      if 'fs_bytes_per_inode' in part or 'fs_inode_size' in part:
        if fs_type not in ('ext2', 'ext4'):
          raise InvalidLayout('Invalid fs: only extX supports inode options')

  def Align(count, alignment):
    """Rounds count up to the next multiple of alignment."""
    offset = count % alignment
    if offset:
      count += alignment - offset
    return count


  def FillExtraValues(layout_name, layout, base=None):
    """Adds inherited defaults and computed sizes/offsets to a layout."""
    # Reserved size for first GPT
    disk_block_count = GPT_RESERVED_SECTORS

    # Fill in default values from base,
    # dict doesn't have a update()+setdefault() method so this looks tedious
    if base:
      for part_num, base_part in base.items():
        part = layout.setdefault(part_num, {})
        for base_key, base_value in base_part.items():
          part.setdefault(base_key, base_value)

    # Partitions are laid out on disk in partition number order.
    for part_num, part in sorted(layout.items(), key=lambda t: int(t[0])):
      if part['type'] == 'blank':
        continue

      part.setdefault('part_alignment', metadata['part_alignment'])
      part['bytes'] = part['blocks'] * metadata['block_size']
      part.setdefault('fs_block_size', metadata['fs_block_size'])
      part.setdefault('fs_blocks', part['bytes'] // part['fs_block_size'])
      part['fs_bytes'] = part['fs_blocks'] * part['fs_block_size']

      if part['fs_bytes'] > part['bytes']:
        raise InvalidLayout(
            'Filesystem may not be larger than partition: %s %s: %d > %d' %
            (layout_name, part_num, part['fs_bytes'], part['bytes']))

      disk_block_count = Align(disk_block_count, part['part_alignment'])
      part['first_block'] = disk_block_count
      part['first_byte'] = disk_block_count * metadata['block_size']
      disk_block_count += part['blocks']

      part.setdefault('uuid', str(uuid.uuid4()))

    # Reserved size for second GPT plus align disk image size
    disk_block_count += GPT_RESERVED_SECTORS
    disk_block_count = Align(disk_block_count, metadata['disk_alignment'])

    # If this is the requested layout stash the disk size into the global
    # metadata. Kinda odd but the best place I've got with this data structure.
    if layout_name == options.disk_layout:
      metadata['blocks'] = disk_block_count
      metadata['bytes'] = disk_block_count * metadata['block_size']


  # Verify 'base' before other layouts because it is inherited by the others
  # Fill in extra/default values in base last so they aren't inherited
  VerifyLayout('base', base)
  for layout_name, layout in config['layouts'].items():
    if layout_name == 'base':
      continue
    VerifyLayout(layout_name, layout, base)
    FillExtraValues(layout_name, layout, base)
  FillExtraValues('base', base)

  return config, config['layouts'][options.disk_layout]


def GetPartitionTableFromConfig(options):
  """Returns only the selected layout from the partition config.

  Args:
    options: Flags passed to the script
  Returns:
    The partitions of the selected layout, i.e. the second item returned
    by LoadPartitionConfig.
  """

  return LoadPartitionConfig(options)[1]


def GetPartitionTableFromImage(options, config, partitions):
  """Merges basic partition info from an existing image into partitions.

  Each partition reported by 'cgpt show -q' gets image_first_block,
  image_first_byte, image_blocks, image_bytes, image_exists and image_compat
  entries; partitions with a configured fs_type also get image_fs_type as
  reported by blkid. Nothing is returned, partitions is updated in place.

  Args:
    options: Flags passed to the script
    config: Complete layout configuration file object
    partitions: Selected layout configuration object, modified in place
  """
  sector_bytes = config['metadata']['block_size']
  listing = subprocess.check_output(
      ['cgpt', 'show', '-q', options.disk_image])

  for row in listing.split('\n'):
    if not row.strip():
      continue
    columns = row.split(None, 3)
    if len(columns) != 4 or not all(c.isdigit() for c in columns[:3]):
      raise Exception('Invalid output from cgpt show -q: %r' % row)

    start, length, number = columns[0], columns[1], columns[2]
    entry = partitions.setdefault(number, {})
    entry['image_first_block'] = int(start)
    entry['image_first_byte'] = int(start) * sector_bytes
    entry['image_blocks'] = int(length)
    entry['image_bytes'] = int(length) * sector_bytes
    entry['image_exists'] = True

    # The image is compatible with the config if the partition is configured,
    # starts at the same position, and is the same size or larger in the
    # layout config.
    entry['image_compat'] = (
        entry.get('type', 'blank') != 'blank' and
        entry['first_block'] == entry['image_first_block'] and
        entry['blocks'] >= entry['image_blocks'])

  # Probe the filesystem actually present in every configured partition that
  # exists in the image and is expected to hold one.
  for entry in partitions.values():
    wants_probe = (entry.get('type', 'blank') != 'blank' and
                   entry.get('image_exists', False) and
                   entry.get('fs_type', None))
    if not wants_probe:
      continue
    with PartitionLoop(options, entry) as loop_dev:
      try:
        detected = subprocess.check_output(
          ['sudo', 'blkid', '-o', 'value', '-s', 'TYPE', loop_dev]).strip()
      except subprocess.CalledProcessError:
        detected = None
      entry['image_fs_type'] = detected

  # Set compat flags for any partition not in the image: only blank ones
  # are considered compatible.
  for entry in partitions.values():
    entry.setdefault('image_exists', False)
    entry.setdefault('image_compat', entry.get('type', 'blank') == 'blank')


def WritePartitionTable(options, config=None, partitions=None):
  """Writes the given partition table to a disk image or device.

  With options.create a fresh table is created, otherwise the existing image
  must be compatible with the layout and is extended as needed. Partitions
  are processed in partition number order so that the choice of the hybrid
  partition and the assigned priorities are deterministic.

  Args:
    options: Flags passed to the script
    config: Complete layout configuration file object
    partitions: Selected layout configuration object
  Raises:
    InvalidLayout: updating an image whose partitions are not compatible
      with the layout
  """

  def Cgpt(*args):
    subprocess.check_call(['cgpt'] + [str(a) for a in args])

  if not (config and partitions):
    config, partitions = LoadPartitionConfig(options)

  if options.create:
    Cgpt('create', '-c', '-s', config['metadata']['blocks'],
        '-g', '00000000-0000-0000-0000-000000000001', options.disk_image)
  else:
    # If we are not creating a fresh image all partitions must be compatible.
    GetPartitionTableFromImage(options, config, partitions)
    if not all(p['image_compat'] for p in partitions.values()):
      raise InvalidLayout("New disk layout is incompatible with existing image")

    # Extend the disk image size as needed
    with open(options.disk_image, 'r+') as image_fd:
      image_fd.truncate(config['metadata']['bytes'])
    Cgpt('repair', options.disk_image)

  hybrid = None
  prioritize = []
  # Plain dict order is arbitrary; walk the table by partition number so the
  # first 'hybrid' partition and the priority order do not depend on hashing.
  ordered = sorted(partitions.items(), key=lambda item: int(item[0]))
  for _, partition in ordered:
    if partition['type'] != 'blank':
      Cgpt('add', '-i', partition['num'],
                  '-b', partition['first_block'],
                  '-s', partition['blocks'],
                  '-t', partition['type'],
                  '-l', partition['label'],
                  '-u', partition['uuid'],
                  options.disk_image)

      features = partition.get('features', [])
      if not hybrid and 'hybrid' in features:
        hybrid = partition['num']
      if 'prioritize' in features:
        prioritize.append(partition)

  if hybrid:
    # Enable legacy boot flag and generate a hybrid MBR partition table
    Cgpt('add', '-i', hybrid, '-B1', options.disk_image)

  # The last prioritized partition gets priority 1, earlier ones get higher
  # values.
  prioritize.reverse()
  for i, partition in enumerate(prioritize):
    Cgpt('add', '-i', partition['num'], '-S1', '-P', i+1, options.disk_image)

  Cgpt('show', options.disk_image)


def Sudo(cmd, stdout_null=False):
  """Runs a command through sudo, optionally discarding its stdout.

  Tools such as tune2fs have no quiet mode and would otherwise flood the
  build output with noise that hides more interesting messages.

  Args:
    cmd: a command and arguments to run.
    stdout_null: bool to enable redirecting stdout to /dev/null.
  """

  argv = ['sudo']
  argv.extend(str(arg) for arg in cmd)

  if not stdout_null:
    subprocess.check_call(argv, stdout=None)
    return

  # The context manager closes /dev/null whether or not the command fails.
  with open('/dev/null', 'w') as sink:
    subprocess.check_call(argv, stdout=sink)

def SudoOutput(cmd):
  """Runs a command through sudo and hands back what it printed.

  Used for parsing the root hash of a partition from veritysetup output.

  Args:
    cmd: a command and arguments to run.

  Returns:
    A bytestring of the command's output
  """

  argv = ['sudo'] + list(map(str, cmd))
  return subprocess.check_output(argv)


def BtrfsSubvolId(path):
  """Get the subvolume id from a given path.

  Args:
    path: path of a btrfs subvolume inside a mounted filesystem
  Returns:
    The integer from the 'Object ID' line of 'btrfs subvolume show'
  Raises:
    Exception: the output contains no 'Object ID' line
  """

  out = subprocess.check_output(
          ['sudo', 'btrfs', 'subvolume', 'show', path])
  m = re.search(r'^\s*Object ID:\s*(\d+)$', out, re.MULTILINE)
  if not m:
    # Format the output into the message; passing it as a second argument
    # would leave the '%r' placeholder unexpanded.
    raise Exception('Failed to parse btrfs output: %r' % (out,))

  return int(m.group(1))


def FormatBtrfs(part, device):
  """Format a btrfs filesystem.

  If the partition defines 'fs_subvolume' the new filesystem is mounted
  temporarily, the subvolume is created and made the default subvolume.

  Args:
    part: dict defining the partition
    device: name of the block device to format
  """
  cmd = ['mkfs.btrfs', '--byte-count', part['fs_bytes']]
  if 'fs_label' in part:
    cmd += ['--label', part['fs_label']]
  Sudo(cmd + [device])

  if part.get('fs_subvolume', None):
    btrfs_mount = tempfile.mkdtemp()
    subvol_path = '%s/%s' % (btrfs_mount, part['fs_subvolume'])
    try:
      Sudo(['mount', '-t', 'btrfs', device, btrfs_mount])
    except Exception:
      # Nothing got mounted, so only the temporary directory needs cleanup.
      os.rmdir(btrfs_mount)
      raise
    try:
      Sudo(['btrfs', 'subvolume', 'create', subvol_path])
      subvol_id = BtrfsSubvolId(subvol_path)
      Sudo(['btrfs', 'subvolume', 'set-default', subvol_id, btrfs_mount])
    finally:
      Sudo(['umount', btrfs_mount])
      os.rmdir(btrfs_mount)


def FormatExt(part, device):
  """Create an ext2 or ext4 filesystem and tune it.

  Args:
    part: dict defining the partition
    device: name of the block device to format
  """
  block_size = part['fs_block_size']
  mkfs_args = [
      'mke2fs', '-q',
      '-t', part['fs_type'],
      '-b', block_size,
      # One inode per filesystem block unless the layout says otherwise.
      '-i', part.get('fs_bytes_per_inode', block_size),
      '-I', part.get('fs_inode_size', 128),
      device,
      part['fs_blocks'],
  ]
  Sudo(mkfs_args)

  tune_args = ['tune2fs', '-e', 'remount-ro']
  if 'fs_label' in part:
    tune_args.extend(['-L', part['fs_label']])
  if part['type'] == 'flatcar-usr':
    tune_args.extend(['-U', 'clear', '-T', '20091119110000'])
    # Disable auto fsck
    tune_args.extend(['-c', '0', '-i', '0'])
    # Disable reserve blocks
    tune_args.extend(['-m', '0', '-r', '0'])
  tune_args.append(device)

  Sudo(tune_args, stdout_null=True)


def FormatFat(part, device):
  """Create a FAT filesystem covering the whole partition.

  Args:
    part: dict defining the partition
    device: name of the block device to format
  """
  mkfs_args = ['mkfs.vfat', '-I']
  if 'fs_label' in part:
    mkfs_args.extend(['-n', part['fs_label']])
  if part['type'] == 'efi':
    # ESP is FAT32 irrespective of size
    mkfs_args.extend(['-F', '32'])

  # The block-count argument to mkfs.vfat is in units of 1k
  size_in_kib = part['bytes'] // 1024
  mkfs_args.extend([device, size_in_kib])

  Sudo(mkfs_args, stdout_null=True)


@contextlib.contextmanager
def PartitionLoop(options, partition):
  """Context manager yielding a loop device mapped onto one partition.

  Setting up the device is attempted up to five times with a pause after
  each failure; the last error is raised if every attempt fails. The device
  is detached again when the context is left.
  """

  losetup_cmd = ['sudo', 'losetup',
                 '--offset', str(partition['first_byte']),
                 '--sizelimit', str(partition['bytes']),
                 '--find', '--show', options.disk_image]

  for attempt in range(5):
    try:
      output = subprocess.check_output(losetup_cmd)
    except subprocess.CalledProcessError as error:
      print("Failed to set up loopback, attempt %d" % attempt)
      last_error = error
      time.sleep(5)
    else:
      loop_dev = output.strip()
      break
  else:
    # Every attempt failed.
    raise last_error

  try:
    yield loop_dev
  finally:
    Sudo(['losetup', '--detach', loop_dev])


def FormatPartition(options, part):
  """Formats one partition of the image with its configured filesystem.

  Args:
    options: Flags passed to the script
    part: dict defining the partition
  """
  fs_type = part['fs_type']
  print("Formatting partition %s (%s) as %s" % (
          part['num'], part['label'], fs_type))

  formatters = {
      'ext2': FormatExt,
      'ext4': FormatExt,
      'btrfs': FormatBtrfs,
      'vfat': FormatFat,
  }

  with PartitionLoop(options, part) as loop_dev:
    formatter = formatters.get(fs_type)
    if formatter is None:
      raise Exception("Unhandled fs type %s" % fs_type)
    formatter(part, loop_dev)


def Format(options):
  """Writes the partition table and creates a fresh filesystem in every
  partition that defines one.

  Args:
    options: Flags passed to the script
  """

  # Note on using sudo: none of this strictly needs root, but mke2fs and
  # friends cannot create filesystems at arbitrary offsets inside a file;
  # loop devices make that possible.

  config, partitions = LoadPartitionConfig(options)
  WritePartitionTable(options, config, partitions)

  for part in partitions.values():
    if part['type'] != 'blank' and 'fs_type' in part:
      FormatPartition(options, part)


def ResizeExt(part, device):
  """Check an ext[234] filesystem and resize it to the configured blocks.

  Args:
    part: dict defining the partition
    device: name of the block device
  """
  fsck_cmd = ['e2fsck', '-p', '-f', device]
  Sudo(fsck_cmd)

  resize_cmd = ['resize2fs', device, str(part['fs_blocks'])]
  Sudo(resize_cmd)


def ResizeBtrfs(part, device):
  """Resize btrfs filesystems.

  The filesystem is mounted on a temporary directory for the duration of
  the resize.

  Args:
    part: dict defining the partition
    device: name of the block device
  """
  btrfs_mount = tempfile.mkdtemp()
  try:
    Sudo(['mount', '-t', 'btrfs', device, btrfs_mount])
  except Exception:
    # Nothing got mounted, so only the temporary directory needs cleanup.
    os.rmdir(btrfs_mount)
    raise
  try:
    Sudo(['btrfs', 'filesystem', 'resize', part['fs_bytes'], btrfs_mount])
  finally:
    Sudo(['umount', btrfs_mount])
    os.rmdir(btrfs_mount)


def Update(options):
  """Updates the partition table of an existing image, formats partitions
  that hold no filesystem yet and resizes those whose size changed.

  Args:
    options: Flags passed to the script
  """

  config, partitions = LoadPartitionConfig(options)
  WritePartitionTable(options, config, partitions)

  # First pass: create filesystems where the image has none.
  for part in partitions.values():
    if part.get('fs_type', None) and not part['image_fs_type']:
      FormatPartition(options, part)

  # Second pass: grow/shrink filesystems of partitions whose size differs
  # from the image.
  for part in partitions.values():
    fs_type = part.get('fs_type', None)
    if not fs_type or part['bytes'] == part['image_bytes']:
      continue

    if fs_type in ('ext2', 'ext4') and IsE2fsReadWrite(options, part):
      resizer = ResizeExt
    elif fs_type == 'btrfs':
      resizer = ResizeBtrfs
    else:
      continue

    print("Resizing partition %s (%s) to %s bytes" % (
            part['num'], part['label'], part['fs_bytes']))

    with PartitionLoop(options, part) as loop_dev:
      resizer(part, loop_dev)


def Mount(options):
  """Mount the filesystems of a disk image below options.mount_dir.

  What exists is taken from the partition table of the image, where to
  mount it (mount points and binds) comes from the disk layout config.

  Args:
    options: Flags passed to the script
  """

  config, partitions = LoadPartitionConfig(options)
  GetPartitionTableFromImage(options, config, partitions)

  # Map absolute mount points to the partitions present in the image.
  mounts = {}
  for part in partitions.values():
    mount_point = part.get('mount', None)
    if (mount_point and mount_point.startswith('/') and
        part.get('image_exists', False)):
      mounts[mount_point] = part

  if '/' not in mounts:
    raise InvalidLayout('No partition defined to mount on /')

  def DoMount(mount):
    target = os.path.realpath(options.mount_dir + mount['mount'])
    fs_type = mount.get('fs_type', None)

    mount_opts = ['loop',
                  'offset=%d' % mount['image_first_byte'],
                  'sizelimit=%d' % mount['image_bytes']]
    # Read-only on request, or when the ext read-only hack is active.
    force_ro = options.read_only or (
        fs_type in ('ext2', 'ext4') and not IsE2fsReadWrite(options, mount))
    if force_ro:
      mount_opts.append('ro')
    if mount.get('fs_subvolume', None):
      mount_opts.append('subvol=%s' % mount['fs_subvolume'])

    Sudo(['mkdir', '-p', target])

    mount_cmd = ['mount', '-t', mount.get('fs_type', 'auto'),
                 '-o', ','.join(mount_opts),
                 options.disk_image, target]
    # This tends to fail, retry if it does
    for attempt in range(5):
      try:
        Sudo(mount_cmd)
        break
      except subprocess.CalledProcessError as e:
        print("Error mounting %s, attempt %d" % (target, attempt))
        failure = e
        time.sleep(5)
    else:
      raise failure

    for src, dst in mount.get('binds', {}).items():
      # src may be relative or absolute, os.path.join handles this.
      bind_src = os.path.realpath(
              options.mount_dir + os.path.join(mount['mount'], src))
      bind_dst = os.path.realpath(options.mount_dir + dst)
      Sudo(['mkdir', '-p', bind_src, bind_dst])
      Sudo(['mount', '--bind', bind_src, bind_dst])

  # Shorter paths first so parents are mounted before their children.
  for mount_point in sorted(mounts, key=len):
    DoMount(mounts[mount_point])

def Umount(options):
  """Recursively unmount options.mount_dir and detach its loop devices.

  Args:
    options: Flags passed to the script
  """
  umount_cmd = ['umount', '--recursive', '--detach-loop']
  umount_cmd.append(options.mount_dir)
  Sudo(umount_cmd)


def Tune2fsReadWrite(options, partition, disable_rw):
  """Enable/Disable read-only hack.

  From common.sh:
  This helper clobbers the ro compat value in our root filesystem.

  When the system is built with --enable_rootfs_verification, bit-precise
  integrity checking is performed.  That precision poses a usability issue on
  systems that automount partitions with recognizable filesystems, such as
  ext2/3/4.  When the filesystem is mounted 'rw', ext2 metadata will be
  automatically updated even if no other writes are performed to the
  filesystem.  In addition, ext2+ does not support a "read-only" flag for a
  given filesystem.  That said, forward and backward compatibility of
  filesystem features are supported by tracking if a new feature breaks r/w or
  just write compatibility.  We abuse the read-only compatibility flag[1] in
  the filesystem header by setting the high order byte (le) to FF.  This tells
  the kernel that features R24-R31 are all enabled.  Since those features are
  undefined on all ext-based filesystem, all standard kernels will refuse to
  mount the filesystem as read-write -- only read-only[2].

  [1] 32-bit flag we are modifying:
   http://git.chromium.org/cgi-bin/gitweb.cgi?p=kernel.git;a=blob;f=include/linux/ext2_fs.h#l417
  [2] Mount behavior is enforced here:
   http://git.chromium.org/cgi-bin/gitweb.cgi?p=kernel.git;a=blob;f=fs/ext2/super.c#l857

  N.B., if the high order feature bits are used in the future, we will need to
        revisit this technique.

  Args:
    options: Flags passed to the script
    partition: Config for partition to manipulate
    disable_rw: Set to true to disable read-write access
  """

  if disable_rw:
    print("Disabling read-write mounting of partition %s (%s)" % (
            partition['num'], partition['label']))
  else:
    print("Enabling read-write mounting of partition %s (%s)" % (
            partition['num'], partition['label']))

  # offset of ro_compat, highest order byte (le 32 bit field)
  flag_offset = 0x464 + 3
  # Exactly one raw byte must be patched, so use a bytes literal and a binary
  # file object; a text-mode write could encode 0xff as more than one byte.
  flag_byte = b'\xff' if disable_rw else b'\x00'
  with open(options.disk_image, 'r+b') as image:
    image.seek(partition['first_byte'] + flag_offset)
    image.write(flag_byte)


def IsE2fsReadWrite(options, partition):
  """Returns True if FS is read-write, False if hack is active.

  Args:
    options: Flags passed to the script
    partition: Config for partition to query
  Raises:
    IOError: the image ends before the flag byte of the partition
  """

  # offset of ro_compat, highest order byte (le 32 bit field)
  flag_offset = 0x464 + 3
  # The image is binary data, read the flag as a raw byte.
  with open(options.disk_image, 'rb') as image:
    image.seek(partition['first_byte'] + flag_offset)
    flag_value = image.read(1)

  if len(flag_value) != 1:
    raise IOError('%s is too short to read the ro_compat flag at offset %d' %
                  (options.disk_image, partition['first_byte'] + flag_offset))

  return ord(flag_value) == 0


def Tune(options):
  """Apply the requested filesystem tweak to one partition of an image.

  Args:
    options: Flags passed to the script
  """

  config, partitions = LoadPartitionConfig(options)
  GetPartitionTableFromImage(options, config, partitions)
  part = GetPartition(partitions, options.partition)

  if not part['image_compat']:
    raise InvalidLayout("Disk layout is incompatible with existing image")

  # Toggling the ext read-write hack is the only tweak available.
  if options.disable2fs_rw is None:
    raise Exception("No options specified!")

  is_ext = part.get('fs_type', None) in ('ext2', 'ext4')
  if not is_ext:
    raise Exception("Partition %s is not a ext2 or ext4" % options.partition)

  Tune2fsReadWrite(options, part, options.disable2fs_rw)


def Verity(options):
  """Hash verity protected filesystems.

  For every partition with the 'verity' feature the hash tree is written
  behind the filesystem by veritysetup. ext2/ext4 filesystems are switched
  to read-only mounting first. When options.root_hash is set the root hash
  is written to that file.

  Args:
    options: Flags passed to the script
  Raises:
    InvalidLayout: a verity partition is incompatible with the image
    Exception: the root hash is missing from the veritysetup output
  """

  config, partitions = LoadPartitionConfig(options)
  GetPartitionTableFromImage(options, config, partitions)

  for part in partitions.values():
    if 'verity' not in part.get('features', []):
      continue

    if not part['image_compat']:
      raise InvalidLayout("Disk layout is incompatible with existing image")

    if part.get('fs_type', None) in ('ext2', 'ext4'):
      Tune2fsReadWrite(options, part, disable_rw=True)

    with PartitionLoop(options, part) as loop_dev:
      # Data and hash tree share the device, the hashes start right after
      # the filesystem.
      verityout = SudoOutput(['veritysetup', 'format', '--hash=sha256',
                  '--data-block-size', part['fs_block_size'],
                  '--hash-block-size', part['fs_block_size'],
                  '--data-blocks', part['fs_blocks'],
                  '--hash-offset', part['fs_bytes'],
                  loop_dev, loop_dev])
      print(verityout.strip())
      m = re.search(r"Root hash:\s+([a-f0-9]{64})$", verityout,
                    re.IGNORECASE|re.MULTILINE)
      if not m:
        raise Exception("Failed to parse verity output!")

      if options.root_hash is not None:
        with open(options.root_hash, "w") as hash_file:
          hash_file.write(m.group(1))
          hash_file.write("\n")


def Extract(options):
  """Copy one partition of the image into a file of its own using dd.

  Args:
    options: Flags passed to the script
  """

  config, partitions = LoadPartitionConfig(options)
  GetPartitionTableFromImage(options, config, partitions)
  part = GetPartition(partitions, options.partition)

  if not part['image_compat']:
    raise InvalidLayout("Disk layout is incompatible with existing image")

  # skip/count are byte values because of the iflag setting.
  dd_cmd = [
      'dd',
      'bs=10MB',
      'iflag=count_bytes,skip_bytes',
      'conv=sparse',
      'status=none',
      'if=%s' % options.disk_image,
      'of=%s' % options.output,
      'skip=%s' % part['image_first_byte'],
      'count=%s' % part['image_bytes'],
  ]
  subprocess.check_call(dd_cmd)


def GetPartitionByNumber(partitions, num):
  """Looks up a non-blank partition by its number.

  Args:
    partitions: Dict of partitions keyed by partition number as a string
    num: Number of partition to find
  Returns:
    An object for the selected partition
  Raises:
    PartitionNotFound: there is no such partition or it is blank
  """

  found = partitions.get(str(num), None)
  if found and found['type'] != 'blank':
    return found

  raise PartitionNotFound('Partition not found')


def GetPartitionByLabel(partitions, label):
  """Looks up a non-blank partition by its label.

  Args:
    partitions: Dict of partitions to search in
    label: Label of partition to find
  Returns:
    An object for the selected partition
  Raises:
    PartitionNotFound: no non-blank partition carries the label
  """
  for candidate in partitions.values():
    if candidate['type'] != 'blank' and candidate['label'] == label:
      return candidate

  raise PartitionNotFound('Partition not found')


def GetPartition(partitions, part_id):
  """Looks up a partition by number or by label.

  An identifier made of digits only is treated as a partition number,
  anything else as a label.

  Args:
    partitions: Dict of partitions to search in
    part_id: Label or number of partition to find
  Returns:
    An object for the selected partition
  """
  by_number = str(part_id).isdigit()
  lookup = GetPartitionByNumber if by_number else GetPartitionByLabel
  return lookup(partitions, part_id)


def GetBlockSize(options):
  """Prints the block size of the partition table.

  Args:
    options: Flags passed to the script
  Prints:
    Block size of all partitions in the layout
  """

  config = LoadPartitionConfig(options)[0]
  metadata = config['metadata']
  print(metadata['block_size'])


def GetFilesystemBlockSize(options):
  """Prints the default filesystem block size from the layout metadata.

  Args:
    options: Flags passed to the script
  Prints:
    Block size of all filesystems in the layout
  """

  config = LoadPartitionConfig(options)[0]
  metadata = config['metadata']
  print(metadata['fs_block_size'])


def GetPartitionSize(options):
  """Prints the size of one partition of the selected layout.

  Args:
    options: Flags passed to the script
  Prints:
    Size of selected partition in bytes
  """

  layout = GetPartitionTableFromConfig(options)
  selected = GetPartitionByNumber(layout, options.partition_num)
  print(selected['bytes'])


def GetFilesystemSize(options):
  """Prints the filesystem size of one partition of the selected layout.

  Falls back to the partition size when no filesystem size is known.

  Args:
    options: Flags passed to the script
  Prints:
    Size of selected partition filesystem in bytes
  """

  layout = GetPartitionTableFromConfig(options)
  selected = GetPartitionByNumber(layout, options.partition_num)
  print(selected.get('fs_bytes', selected['bytes']))


def GetLabel(options):
  """Prints the label of the partition with the given number.

  Args:
    options: Flags passed to the script
  Prints:
    Label of selected partition, or 'UNTITLED' if none specified
  """

  layout = GetPartitionTableFromConfig(options)
  selected = GetPartitionByNumber(layout, options.partition_num)
  print(selected.get('label', 'UNTITLED'))


def GetNum(options):
  """Prints the number of the partition with the given label.

  Args:
    options: Flags passed to the script
  Prints:
    Number of selected partition, or '-1' if there is no number
  """

  layout = GetPartitionTableFromConfig(options)
  selected = GetPartitionByLabel(layout, options.label)
  print(selected.get('num', '-1'))


def GetUuid(options):
  """Prints the unique partition UUID of the partition with the given label.

  Args:
    options: Flags passed to the script
  Prints:
    String containing the requested UUID
  """

  layout = GetPartitionTableFromConfig(options)
  selected = GetPartitionByLabel(layout, options.label)
  print(selected.get('uuid', ''))


def DoDebugOutput(options):
  """Prints out a human readable disk layout in on-disk order.

  Sizes of 1MB and more are rounded down to whole MB; this exists to
  quickly verify visually that a layout looks correct.

  Args:
    options: Flags passed to the script
  """
  partitions = GetPartitionTableFromConfig(options)

  def HumanSize(num_bytes):
    """Formats a byte count as bytes, or whole MB from 1MB upwards."""
    if num_bytes < 1024 * 1024:
      return '%d bytes' % num_bytes
    return '%d MB' % (num_bytes // 1024 // 1024)

  # Keys are partition numbers stored as strings; sort them numerically so
  # that e.g. partition 10 is listed after partition 9, matching the order
  # in which the partitions are placed on disk.
  for num, partition in sorted(partitions.items(),
                               key=lambda item: int(item[0])):
    if partition['type'] == 'blank':
      print('%s: blank' % num)
      continue

    size = HumanSize(partition['bytes'])
    if 'fs_bytes' in partition:
      fs_size = HumanSize(partition['fs_bytes'])
      print('%s: %s - %s/%s' % (num, partition['label'], fs_size, size))
    else:
      print('%s: %s - %s' % (num, partition['label'], size))


def DoParseOnly(options):
  """Loads and validates the layout file without producing any output.

  Meant to be run before reading sizes so that errors show up early.

  Args:
    options: Flags passed to the script
  """
  LoadPartitionConfig(options)


def main(argv):
  """Parses the command line and runs the selected action.

  Defaults for the layout file and layout type can be supplied through the
  DISK_LAYOUT_FILE and DISK_LAYOUT_TYPE environment variables.

  Args:
    argv: Full argument vector, including the program name at index 0
  """
  bundled_layout = os.path.join(os.path.dirname(__file__), 'disk_layout.json')
  default_layout_file = os.environ.get('DISK_LAYOUT_FILE', bundled_layout)
  default_layout_type = os.environ.get('DISK_LAYOUT_TYPE', 'base')

  # Help texts shared by several actions.
  image_help = 'path to disk image file'
  mount_dir_help = 'path to root filesystem mount point'
  part_help = 'number or label of partition to edit'
  part_num_help = 'partition number'
  label_help = 'partition label'

  parser = argparse.ArgumentParser(
          formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--disk_layout_file', default=default_layout_file,
          help='path to disk layout json file')
  parser.add_argument('--disk_layout', default=default_layout_type,
          help='disk layout type from the json file')
  actions = parser.add_subparsers(title='actions')

  # Actions that modify a disk image.
  write_gpt = actions.add_parser(
          'write_gpt', help='write/update partition table')
  write_gpt.add_argument('--create', action='store_true', default=True,
          help='initialize new partition table')
  write_gpt.add_argument('--update', action='store_false', dest='create',
          help='update existing partition table')
  write_gpt.add_argument('disk_image', help=image_help)
  write_gpt.set_defaults(func=WritePartitionTable)

  format_cmd = actions.add_parser(
          'format', help='write gpt and filesystems to image')
  format_cmd.add_argument('disk_image', help=image_help)
  format_cmd.set_defaults(func=Format, create=True)

  update_cmd = actions.add_parser('update',
          help='write gpt, resize filesystems, and format free partitions')
  update_cmd.add_argument('disk_image', help=image_help)
  update_cmd.set_defaults(func=Update, create=False)

  mount_cmd = actions.add_parser('mount', help='mount filesystems in image')
  mount_cmd.add_argument('--read_only', '-r', action='store_true',
          help='mount filesystems read-only')
  mount_cmd.add_argument('disk_image', help=image_help)
  mount_cmd.add_argument('mount_dir', help=mount_dir_help)
  mount_cmd.set_defaults(func=Mount)

  umount_cmd = actions.add_parser(
          'umount', help='unmount a image mount point')
  umount_cmd.add_argument('mount_dir', help=mount_dir_help)
  umount_cmd.set_defaults(func=Umount)

  tune_cmd = actions.add_parser('tune', help='tweak filesystem options')
  tune_cmd.add_argument('--disable2fs_rw', action='store_true', default=None,
          help='disable mounting ext2 filesystems read-write')
  tune_cmd.add_argument('--enable2fs_rw', action='store_false',
          dest='disable2fs_rw',
          help='re-enable mounting ext2 filesystems read-write')
  tune_cmd.add_argument('disk_image', help=image_help)
  tune_cmd.add_argument('partition', help=part_help)
  tune_cmd.set_defaults(func=Tune)

  verity_cmd = actions.add_parser('verity', help='compute verity hashes')
  verity_cmd.add_argument('disk_image', help=image_help)
  verity_cmd.add_argument('--root_hash',
          help='name of file to contain root hash')
  verity_cmd.set_defaults(func=Verity)

  extract_cmd = actions.add_parser(
          'extract', help='extract a single partition')
  extract_cmd.add_argument('disk_image', help=image_help)
  extract_cmd.add_argument('partition', help=part_help)
  extract_cmd.add_argument('output',
          help='path to write the partition image to')
  extract_cmd.set_defaults(func=Extract)

  # Actions that only read the layout config.
  read_block_size = actions.add_parser(
          'readblocksize', help='get device block size')
  read_block_size.set_defaults(func=GetBlockSize)

  read_fs_block_size = actions.add_parser(
          'readfsblocksize', help='get filesystem block size')
  read_fs_block_size.set_defaults(func=GetFilesystemBlockSize)

  read_part_size = actions.add_parser(
          'readpartsize', help='get partition size')
  read_part_size.add_argument('partition_num', type=int, help=part_num_help)
  read_part_size.set_defaults(func=GetPartitionSize)

  read_fs_size = actions.add_parser('readfssize', help='get filesystem size')
  read_fs_size.add_argument('partition_num', type=int, help=part_num_help)
  read_fs_size.set_defaults(func=GetFilesystemSize)

  read_label = actions.add_parser('readlabel', help='get partition label')
  read_label.add_argument('partition_num', type=int, help=part_num_help)
  read_label.set_defaults(func=GetLabel)

  read_num = actions.add_parser('readnum', help='get partition number')
  read_num.add_argument('label', help=label_help)
  read_num.set_defaults(func=GetNum)

  read_uuid = actions.add_parser('readuuid', help='get partition uuid')
  read_uuid.add_argument('label', help=label_help)
  read_uuid.set_defaults(func=GetUuid)

  debug_cmd = actions.add_parser('debug', help='dump debug output')
  debug_cmd.set_defaults(func=DoDebugOutput)

  parse_only = actions.add_parser('parseonly', help='validate config')
  parse_only.set_defaults(func=DoParseOnly)

  options = parser.parse_args(argv[1:])
  options.func(options)


# Run the command line interface only when executed as a script, not when
# the module is imported.
if __name__ == '__main__':
  main(sys.argv)
