# -*- coding: utf-8 -*-
"""
args
~~~~
This module provides the CLI argument interface for clint.
"""
import os
from sys import argv
from glob import glob
from collections import OrderedDict
def _expand_path(path):
    """Expands directories and globs in given path."""
    paths = []
    path = os.path.expanduser(path)
    path = os.path.expandvars(path)
    if os.path.isdir(path):
        for (dirpath, dirnames, filenames) in os.walk(path):
            for filename in filenames:
                paths.append(os.path.join(dirpath, filename))
    else:
        paths.extend(glob(path))
    return paths
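# Illustrative sketch (hypothetical paths): for a directory tree ``src/``
# holding ``a.py`` and ``b.py``, the directory branch walks the tree, while a
# glob argument defers to ``glob()``:
#
#     >>> _expand_path('src')
#     ['src/a.py', 'src/b.py']
#     >>> _expand_path('src/*.py')
#     ['src/a.py', 'src/b.py']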
# ``basestring`` exists only on Python 2; fall back to ``str`` on Python 3.
try:
    _string_types = basestring  # noqa: F821
except NameError:
    _string_types = str
def _is_collection(obj):
    """Tests if an object is a collection. Strings don't count."""
    if isinstance(obj, _string_types):
        return False
    return hasattr(obj, '__getitem__')
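# Illustrative behaviour: strings are indexable but deliberately excluded.
#
#     >>> _is_collection(['a', 'b'])
#     True
#     >>> _is_collection('ab')
#     False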
class ArgsList(object):
"""CLI Argument management."""
def __init__(self, args=None, no_argv=False):
if not args:
if not no_argv:
self._args = argv[1:]
else:
self._args = []
else:
self._args = args
def __len__(self):
return len(self._args)
def __repr__(self):
return '<args %s>' % (repr(self._args))
def __getitem__(self, i):
try:
return self.all[i]
except IndexError:
return None
def __contains__(self, x):
return self.first(x) is not None
    def get(self, x):
        """Returns argument at given index, else None."""
        try:
            return self.all[x]
        except IndexError:
            return None
    def get_with(self, x):
        """Returns first argument that contains given string, else None."""
        found = self.first_with(x)
        return self.all[found] if found is not None else None
    def remove(self, x):
        """Removes given arg (or list thereof) from Args object."""
        def _remove(x):
            found = self.first(x)
            if found is not None:
                self._args.pop(found)
        if _is_collection(x):
            for item in x:
                _remove(item)
        else:
            _remove(x)
    def pop(self, x):
        """Removes and returns value at given index, else None."""
try:
return self._args.pop(x)
except IndexError:
return None
    def any_contain(self, x):
        """Tests if given string is contained in any stored argument."""
        return self.first_with(x) is not None
def contains(self, x):
"""Tests if given object is in arguments list.
Accepts strings and lists of strings."""
return self.__contains__(x)
def first(self, x):
"""Returns first found index of given value (or list of values)"""
        def _find(x):
try:
return self.all.index(str(x))
except ValueError:
return None
if _is_collection(x):
for item in x:
found = _find(item)
if found is not None:
return found
return None
else:
return _find(x)
    def first_with(self, x):
        """Returns first found index containing value (or list of values)"""
        def _find(x):
            for i, arg in enumerate(self.all):
                if x in arg:
                    return i
            return None
        if _is_collection(x):
            for item in x:
                found = _find(item)
                if found is not None:
                    return found
            return None
        else:
            return _find(x)
    def first_without(self, x):
        """Returns first found index not containing value (or list of values)"""
        def _find(x):
            for i, arg in enumerate(self.all):
                if x not in arg:
                    return i
            return None
        if _is_collection(x):
            for item in x:
                found = _find(item)
                if found is not None:
                    return found
            return None
        else:
            return _find(x)
    def start_with(self, x):
        """Returns all arguments beginning with given string (or list thereof)"""
        _args = []
        for arg in self.all:
            if _is_collection(x):
                for _x in x:
                    if arg.startswith(_x):
                        _args.append(arg)
                        break
            else:
                if arg.startswith(x):
                    _args.append(arg)
        return ArgsList(_args, no_argv=True)
    def contains_at(self, x, index):
        """Tests if given [list of] string is at given index."""
        try:
            if _is_collection(x):
                for _x in x:
                    if (_x in self.all[index]) or (_x == self.all[index]):
                        return True
                return False
            else:
                return (x in self.all[index])
        except IndexError:
            return False
def has(self, x):
"""Returns true if argument exists at given index.
Accepts: integer.
"""
try:
self.all[x]
return True
except IndexError:
return False
def value_after(self, x):
"""Returns value of argument after given found argument (or list thereof)."""
try:
try:
i = self.all.index(x)
except ValueError:
return None
return self.all[i + 1]
except IndexError:
return None
@property
def grouped(self):
"""Extracts --flag groups from argument list.
        Returns {flag: ArgsList, ...}
"""
collection = OrderedDict(_=ArgsList(no_argv=True))
_current_group = None
for arg in self.all:
if arg.startswith('-'):
_current_group = arg
collection.setdefault(arg, ArgsList(no_argv=True))
else:
if _current_group:
collection[_current_group]._args.append(arg)
else:
collection['_']._args.append(arg)
return collection
@property
def last(self):
"""Returns last argument."""
try:
return self.all[-1]
except IndexError:
return None
@property
def all(self):
"""Returns all arguments."""
return self._args
def all_with(self, x):
"""Returns all arguments containing given string (or list thereof)"""
_args = []
for arg in self.all:
if _is_collection(x):
for _x in x:
if _x in arg:
_args.append(arg)
break
else:
if x in arg:
_args.append(arg)
return ArgsList(_args, no_argv=True)
    def all_without(self, x):
        """Returns all arguments not containing given string (or list thereof)"""
        _args = []
        for arg in self.all:
            if _is_collection(x):
                if not any(_x in arg for _x in x):
                    _args.append(arg)
            else:
                if x not in arg:
                    _args.append(arg)
        return ArgsList(_args, no_argv=True)
    @property
    def flags(self):
        """Returns ArgsList object including only flagged arguments."""
        return self.start_with('-')
    @property
    def not_flags(self):
        """Returns ArgsList object excluding flagged arguments."""
        return self.all_without('-')
    @property
    def files(self):
        """Returns an expanded list of all valid paths that were passed in."""
        _paths = []
        for arg in self.all:
            for path in _expand_path(arg):
                if os.path.exists(path):
                    _paths.append(path)
        return _paths
@property
def not_files(self):
"""Returns a list of all arguments that aren't files/globs."""
_args = []
for arg in self.all:
if not len(_expand_path(arg)):
if not os.path.exists(arg):
_args.append(arg)
return ArgsList(_args, no_argv=True)
    @property
    def copy(self):
        """Returns a copy of Args object for temporary manipulation."""
        return ArgsList(list(self.all))
args = ArgsList()
get = args.get
get_with = args.get_with
remove = args.remove
pop = args.pop
any_contain = args.any_contain
contains = args.contains
first = args.first
first_with = args.first_with
first_without = args.first_without
start_with = args.start_with
contains_at = args.contains_at
has = args.has
value_after = args.value_after
grouped = args.grouped
last = args.last
all = args.all
all_with = args.all_with
all_without = args.all_without
flags = args.flags
not_flags = args.not_flags
files = args.files
not_files = args.not_files
copy = args.copy
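# Usage sketch (the flag names and values below are hypothetical, assuming an
# invocation such as ``prog a b -v --output out.txt``):
#
#     >>> from clint import args
#     >>> args.flags.all
#     ['-v', '--output']
#     >>> args.grouped['--output'].all
#     ['out.txt']
#     >>> args.value_after('--output')
#     'out.txt'
#     >>> args.grouped['_'].all       # positional args before the first flag
#     ['a', 'b']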
"""
Cycler
======
Cycling through combinations of values, producing dictionaries.
You can add cyclers::
from cycler import cycler
cc = (cycler(color=list('rgb')) +
cycler(linestyle=['-', '--', '-.']))
for d in cc:
print(d)
Results in::
{'color': 'r', 'linestyle': '-'}
{'color': 'g', 'linestyle': '--'}
{'color': 'b', 'linestyle': '-.'}
You can multiply cyclers::
from cycler import cycler
cc = (cycler(color=list('rgb')) *
cycler(linestyle=['-', '--', '-.']))
for d in cc:
print(d)
Results in::
{'color': 'r', 'linestyle': '-'}
{'color': 'r', 'linestyle': '--'}
{'color': 'r', 'linestyle': '-.'}
{'color': 'g', 'linestyle': '-'}
{'color': 'g', 'linestyle': '--'}
{'color': 'g', 'linestyle': '-.'}
{'color': 'b', 'linestyle': '-'}
{'color': 'b', 'linestyle': '--'}
{'color': 'b', 'linestyle': '-.'}
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import six
from itertools import product, cycle
from six.moves import zip, reduce
from operator import mul, add
import copy
__version__ = '0.10.0'
def _process_keys(left, right):
"""
Helper function to compose cycler keys
Parameters
----------
left, right : iterable of dictionaries or None
The cyclers to be composed
Returns
-------
keys : set
The keys in the composition of the two cyclers
"""
l_peek = next(iter(left)) if left is not None else {}
r_peek = next(iter(right)) if right is not None else {}
l_key = set(l_peek.keys())
r_key = set(r_peek.keys())
if l_key & r_key:
raise ValueError("Can not compose overlapping cycles")
return l_key | r_key
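# Illustrative sketch: base cyclers are iterables of single-key dicts, so two
# disjoint operands compose to the union of their keys, while any overlap
# raises the ValueError above.
#
#     >>> sorted(_process_keys([{'color': 'r'}], [{'lw': 1}]))
#     ['color', 'lw']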
class Cycler(object):
"""
Composable cycles
    This class has composition methods:
``+``
for 'inner' products (zip)
``+=``
in-place ``+``
``*``
for outer products (itertools.product) and integer multiplication
``*=``
in-place ``*``
and supports basic slicing via ``[]``
Parameters
----------
left : Cycler or None
The 'left' cycler
right : Cycler or None
The 'right' cycler
op : func or None
Function which composes the 'left' and 'right' cyclers.
"""
def __call__(self):
return cycle(self)
def __init__(self, left, right=None, op=None):
"""Semi-private init
Do not use this directly, use `cycler` function instead.
"""
if isinstance(left, Cycler):
self._left = Cycler(left._left, left._right, left._op)
elif left is not None:
# Need to copy the dictionary or else that will be a residual
# mutable that could lead to strange errors
self._left = [copy.copy(v) for v in left]
else:
self._left = None
if isinstance(right, Cycler):
self._right = Cycler(right._left, right._right, right._op)
elif right is not None:
# Need to copy the dictionary or else that will be a residual
# mutable that could lead to strange errors
self._right = [copy.copy(v) for v in right]
else:
self._right = None
self._keys = _process_keys(self._left, self._right)
self._op = op
@property
def keys(self):
"""
The keys this Cycler knows about
"""
return set(self._keys)
def change_key(self, old, new):
"""
Change a key in this cycler to a new name.
Modification is performed in-place.
Does nothing if the old key is the same as the new key.
Raises a ValueError if the new key is already a key.
Raises a KeyError if the old key isn't a key.
"""
if old == new:
return
if new in self._keys:
raise ValueError("Can't replace %s with %s, %s is already a key" %
(old, new, new))
if old not in self._keys:
raise KeyError("Can't replace %s with %s, %s is not a key" %
(old, new, old))
self._keys.remove(old)
self._keys.add(new)
if self._right is not None and old in self._right.keys:
self._right.change_key(old, new)
# self._left should always be non-None
# if self._keys is non-empty.
elif isinstance(self._left, Cycler):
self._left.change_key(old, new)
else:
# It should be completely safe at this point to
# assume that the old key can be found in each
# iteration.
self._left = [{new: entry[old]} for entry in self._left]
def _compose(self):
"""
Compose the 'left' and 'right' components of this cycle
with the proper operation (zip or product as of now)
"""
for a, b in self._op(self._left, self._right):
out = dict()
out.update(a)
out.update(b)
yield out
@classmethod
def _from_iter(cls, label, itr):
"""
Class method to create 'base' Cycler objects
that do not have a 'right' or 'op' and for which
the 'left' object is not another Cycler.
Parameters
----------
label : str
The property key.
itr : iterable
Finite length iterable of the property values.
Returns
-------
cycler : Cycler
New 'base' `Cycler`
"""
ret = cls(None)
ret._left = list({label: v} for v in itr)
ret._keys = set([label])
return ret
def __getitem__(self, key):
# TODO : maybe add numpy style fancy slicing
if isinstance(key, slice):
trans = self.by_key()
return reduce(add, (_cycler(k, v[key])
for k, v in six.iteritems(trans)))
else:
raise ValueError("Can only use slices with Cycler.__getitem__")
def __iter__(self):
if self._right is None:
return iter(dict(l) for l in self._left)
return self._compose()
def __add__(self, other):
"""
Pair-wise combine two equal length cycles (zip)
Parameters
----------
other : Cycler
The second Cycler
"""
if len(self) != len(other):
raise ValueError("Can only add equal length cycles, "
"not {0} and {1}".format(len(self), len(other)))
return Cycler(self, other, zip)
def __mul__(self, other):
"""
Outer product of two cycles (`itertools.product`) or integer
multiplication.
Parameters
----------
other : Cycler or int
The second Cycler or integer
"""
if isinstance(other, Cycler):
return Cycler(self, other, product)
elif isinstance(other, int):
trans = self.by_key()
return reduce(add, (_cycler(k, v*other)
for k, v in six.iteritems(trans)))
else:
return NotImplemented
def __rmul__(self, other):
return self * other
def __len__(self):
op_dict = {zip: min, product: mul}
if self._right is None:
return len(self._left)
l_len = len(self._left)
r_len = len(self._right)
return op_dict[self._op](l_len, r_len)
def __iadd__(self, other):
"""
In-place pair-wise combine two equal length cycles (zip)
Parameters
----------
other : Cycler
The second Cycler
"""
if not isinstance(other, Cycler):
raise TypeError("Cannot += with a non-Cycler object")
# True shallow copy of self is fine since this is in-place
old_self = copy.copy(self)
self._keys = _process_keys(old_self, other)
self._left = old_self
self._op = zip
self._right = Cycler(other._left, other._right, other._op)
return self
def __imul__(self, other):
"""
In-place outer product of two cycles (`itertools.product`)
Parameters
----------
other : Cycler
The second Cycler
"""
if not isinstance(other, Cycler):
raise TypeError("Cannot *= with a non-Cycler object")
# True shallow copy of self is fine since this is in-place
old_self = copy.copy(self)
self._keys = _process_keys(old_self, other)
self._left = old_self
self._op = product
self._right = Cycler(other._left, other._right, other._op)
return self
def __eq__(self, other):
"""
Check equality
"""
if len(self) != len(other):
return False
if self.keys ^ other.keys:
return False
return all(a == b for a, b in zip(self, other))
def __repr__(self):
op_map = {zip: '+', product: '*'}
if self._right is None:
lab = self.keys.pop()
itr = list(v[lab] for v in self)
return "cycler({lab!r}, {itr!r})".format(lab=lab, itr=itr)
else:
op = op_map.get(self._op, '?')
msg = "({left!r} {op} {right!r})"
return msg.format(left=self._left, op=op, right=self._right)
def _repr_html_(self):
        # a table showing the value of each key through a full cycle
output = "<table>"
sorted_keys = sorted(self.keys, key=repr)
for key in sorted_keys:
output += "<th>{key!r}</th>".format(key=key)
for d in iter(self):
output += "<tr>"
for k in sorted_keys:
output += "<td>{val!r}</td>".format(val=d[k])
output += "</tr>"
output += "</table>"
return output
def by_key(self):
"""Values by key
This returns the transposed values of the cycler. Iterating
over a `Cycler` yields dicts with a single value for each key,
this method returns a `dict` of `list` which are the values
for the given key.
The returned value can be used to create an equivalent `Cycler`
using only `+`.
Returns
-------
transpose : dict
dict of lists of the values for each key.
"""
# TODO : sort out if this is a bottle neck, if there is a better way
# and if we care.
keys = self.keys
# change this to dict comprehension when drop 2.6
out = dict((k, list()) for k in keys)
for d in self:
for k in keys:
out[k].append(d[k])
return out
# for back compatibility
_transpose = by_key
def simplify(self):
"""Simplify the Cycler
Returned as a composition using only sums (no multiplications)
Returns
-------
simple : Cycler
An equivalent cycler using only summation"""
# TODO: sort out if it is worth the effort to make sure this is
        # balanced. Currently it is
        # (((a + b) + c) + d) vs
        # ((a + b) + (c + d))
        # There may be some performance implications.
trans = self.by_key()
return reduce(add, (_cycler(k, v) for k, v in six.iteritems(trans)))
def concat(self, other):
"""Concatenate this cycler and an other.
The keys must match exactly.
This returns a single Cycler which is equivalent to
`itertools.chain(self, other)`
Examples
--------
>>> num = cycler('a', range(3))
>>> let = cycler('a', 'abc')
>>> num.concat(let)
cycler('a', [0, 1, 2, 'a', 'b', 'c'])
Parameters
----------
other : `Cycler`
The `Cycler` to concatenate to this one.
Returns
-------
ret : `Cycler`
The concatenated `Cycler`
"""
return concat(self, other)
def concat(left, right):
"""Concatenate two cyclers.
The keys must match exactly.
This returns a single Cycler which is equivalent to
`itertools.chain(left, right)`
Examples
--------
>>> num = cycler('a', range(3))
>>> let = cycler('a', 'abc')
    >>> concat(num, let)
cycler('a', [0, 1, 2, 'a', 'b', 'c'])
Parameters
----------
left, right : `Cycler`
The two `Cycler` instances to concatenate
Returns
-------
ret : `Cycler`
The concatenated `Cycler`
"""
if left.keys != right.keys:
msg = '\n\t'.join(["Keys do not match:",
"Intersection: {both!r}",
"Disjoint: {just_one!r}"]).format(
both=left.keys & right.keys,
just_one=left.keys ^ right.keys)
raise ValueError(msg)
_l = left.by_key()
_r = right.by_key()
return reduce(add, (_cycler(k, _l[k] + _r[k]) for k in left.keys))
def cycler(*args, **kwargs):
"""
Create a new `Cycler` object from a single positional argument,
a pair of positional arguments, or the combination of keyword arguments.
cycler(arg)
cycler(label1=itr1[, label2=iter2[, ...]])
cycler(label, itr)
Form 1 simply copies a given `Cycler` object.
Form 2 composes a `Cycler` as an inner product of the
pairs of keyword arguments. In other words, all of the
iterables are cycled simultaneously, as if through zip().
Form 3 creates a `Cycler` from a label and an iterable.
This is useful for when the label cannot be a keyword argument
(e.g., an integer or a name that has a space in it).
Parameters
----------
arg : Cycler
Copy constructor for Cycler (does a shallow copy of iterables).
label : name
The property key. In the 2-arg form of the function,
the label can be any hashable object. In the keyword argument
form of the function, it must be a valid python identifier.
itr : iterable
Finite length iterable of the property values.
Can be a single-property `Cycler` that would
be like a key change, but as a shallow copy.
Returns
-------
cycler : Cycler
New `Cycler` for the given property
"""
    if args and kwargs:
        raise TypeError("cycler() can only accept positional OR keyword "
                        "arguments -- not both.")
    if len(args) == 1:
        if not isinstance(args[0], Cycler):
            raise TypeError("If only one positional argument given, it must "
                            "be a Cycler instance.")
return Cycler(args[0])
elif len(args) == 2:
return _cycler(*args)
elif len(args) > 2:
raise TypeError("Only a single Cycler can be accepted as the lone "
"positional argument. Use keyword arguments instead.")
if kwargs:
return reduce(add, (_cycler(k, v) for k, v in six.iteritems(kwargs)))
raise TypeError("Must have at least a positional OR keyword arguments")
def _cycler(label, itr):
"""
Create a new `Cycler` object from a property name and
iterable of values.
Parameters
----------
label : hashable
The property key.
itr : iterable
Finite length iterable of the property values.
Returns
-------
cycler : Cycler
New `Cycler` for the given property
"""
if isinstance(itr, Cycler):
keys = itr.keys
if len(keys) != 1:
msg = "Can not create Cycler from a multi-property Cycler"
raise ValueError(msg)
lab = keys.pop()
# Doesn't need to be a new list because
# _from_iter() will be creating that new list anyway.
itr = (v[lab] for v in itr)
return Cycler._from_iter(label, itr)
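# Usage sketch covering the three call forms of ``cycler()``:
#
#     >>> c = cycler(color=['r', 'g'])     # form 2: keyword argument
#     >>> c2 = cycler('lw', [1, 2])        # form 3: label + iterable
#     >>> list(c + c2)                     # inner product via ``+`` (zip)
#     [{'color': 'r', 'lw': 1}, {'color': 'g', 'lw': 2}]
#     >>> cycler(c)                        # form 1: copy constructor
#     cycler('color', ['r', 'g'])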
# Copyright 2015,2016,2017 Nir Cohen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The ``distro`` package (``distro`` stands for Linux Distribution) provides
information about the Linux distribution it runs on, such as a reliable
machine-readable distro ID, or version information.
It is the recommended replacement for Python's original
:py:func:`platform.linux_distribution` function, but it provides much more
functionality. An alternative implementation became necessary because Python
3.5 deprecated this function, and Python 3.8 removed it altogether. Its
predecessor function :py:func:`platform.dist` was already deprecated since
Python 2.6 and removed in Python 3.8. Still, there are many cases in which
access to OS distribution information is needed. See `Python issue 1322
<https://bugs.python.org/issue1322>`_ for more information.
"""
import argparse
import json
import logging
import os
import re
import shlex
import subprocess
import sys
import warnings
__version__ = "1.6.0"
# Use `if False` to avoid an ImportError on Python 2. After dropping Python 2
# support, can use typing.TYPE_CHECKING instead. See:
# https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING
if False: # pragma: nocover
from typing import (
Any,
Callable,
Dict,
Iterable,
Optional,
Sequence,
TextIO,
Tuple,
Type,
TypedDict,
Union,
)
VersionDict = TypedDict(
"VersionDict", {"major": str, "minor": str, "build_number": str}
)
InfoDict = TypedDict(
"InfoDict",
{
"id": str,
"version": str,
"version_parts": VersionDict,
"like": str,
"codename": str,
},
)
_UNIXCONFDIR = os.environ.get("UNIXCONFDIR", "/etc")
_UNIXUSRLIBDIR = os.environ.get("UNIXUSRLIBDIR", "/usr/lib")
_OS_RELEASE_BASENAME = "os-release"
#: Translation table for normalizing the "ID" attribute defined in os-release
#: files, for use by the :func:`distro.id` method.
#:
#: * Key: Value as defined in the os-release file, translated to lower case,
#: with blanks translated to underscores.
#:
#: * Value: Normalized value.
NORMALIZED_OS_ID = {
"ol": "oracle", # Oracle Linux
}
#: Translation table for normalizing the "Distributor ID" attribute returned by
#: the lsb_release command, for use by the :func:`distro.id` method.
#:
#: * Key: Value as returned by the lsb_release command, translated to lower
#: case, with blanks translated to underscores.
#:
#: * Value: Normalized value.
NORMALIZED_LSB_ID = {
"enterpriseenterpriseas": "oracle", # Oracle Enterprise Linux 4
"enterpriseenterpriseserver": "oracle", # Oracle Linux 5
"redhatenterpriseworkstation": "rhel", # RHEL 6, 7 Workstation
"redhatenterpriseserver": "rhel", # RHEL 6, 7 Server
"redhatenterprisecomputenode": "rhel", # RHEL 6 ComputeNode
}
#: Translation table for normalizing the distro ID derived from the file name
#: of distro release files, for use by the :func:`distro.id` method.
#:
#: * Key: Value as derived from the file name of a distro release file,
#: translated to lower case, with blanks translated to underscores.
#:
#: * Value: Normalized value.
NORMALIZED_DISTRO_ID = {
"redhat": "rhel", # RHEL 6.x, 7.x
}
# Pattern for content of distro release file (reversed)
_DISTRO_RELEASE_CONTENT_REVERSED_PATTERN = re.compile(
r"(?:[^)]*\)(.*)\()? *(?:STL )?([\d.+\-a-z]*\d) *(?:esaeler *)?(.+)"
)
# Pattern for base file name of distro release file
_DISTRO_RELEASE_BASENAME_PATTERN = re.compile(r"(\w+)[-_](release|version)$")
# Base file names to be ignored when searching for distro release file
_DISTRO_RELEASE_IGNORE_BASENAMES = (
"debian_version",
"lsb-release",
"oem-release",
_OS_RELEASE_BASENAME,
"system-release",
"plesk-release",
"iredmail-release",
)
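# Illustrative sketch of the reversed pattern above: the release line is
# matched back-to-front, which makes the trailing version and codename easy to
# anchor. Group 3 is the (reversed) name, group 2 the version, group 1 the
# codename:
#
#     >>> line = "CentOS Linux release 7.1.1503 (Core)"
#     >>> m = _DISTRO_RELEASE_CONTENT_REVERSED_PATTERN.match(line[::-1])
#     >>> (m.group(3)[::-1], m.group(2)[::-1], m.group(1)[::-1])
#     ('CentOS Linux', '7.1.1503', 'Core')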
def linux_distribution(full_distribution_name=True):
# type: (bool) -> Tuple[str, str, str]
"""
.. deprecated:: 1.6.0
:func:`distro.linux_distribution()` is deprecated. It should only be
used as a compatibility shim with Python's
:py:func:`platform.linux_distribution()`. Please use :func:`distro.id`,
:func:`distro.version` and :func:`distro.name` instead.
Return information about the current OS distribution as a tuple
``(id_name, version, codename)`` with items as follows:
* ``id_name``: If *full_distribution_name* is false, the result of
:func:`distro.id`. Otherwise, the result of :func:`distro.name`.
* ``version``: The result of :func:`distro.version`.
* ``codename``: The result of :func:`distro.codename`.
The interface of this function is compatible with the original
:py:func:`platform.linux_distribution` function, supporting a subset of
its parameters.
The data it returns may not exactly be the same, because it uses more data
sources than the original function, and that may lead to different data if
the OS distribution is not consistent across multiple data sources it
provides (there are indeed such distributions ...).
Another reason for differences is the fact that the :func:`distro.id`
method normalizes the distro ID string to a reliable machine-readable value
for a number of popular OS distributions.
"""
warnings.warn(
"distro.linux_distribution() is deprecated. It should only be used as a "
"compatibility shim with Python's platform.linux_distribution(). Please use "
"distro.id(), distro.version() and distro.name() instead.",
DeprecationWarning,
stacklevel=2,
)
return _distro.linux_distribution(full_distribution_name)
def id():
# type: () -> str
"""
Return the distro ID of the current distribution, as a
machine-readable string.
For a number of OS distributions, the returned distro ID value is
*reliable*, in the sense that it is documented and that it does not change
across releases of the distribution.
This package maintains the following reliable distro ID values:
    ============== =========================================
    Distro ID      Distribution
    ============== =========================================
    "ubuntu"       Ubuntu
    "debian"       Debian
    "rhel"         RedHat Enterprise Linux
    "centos"       CentOS
    "fedora"       Fedora
    "sles"         SUSE Linux Enterprise Server
    "opensuse"     openSUSE
    "amazon"       Amazon Linux
    "arch"         Arch Linux
    "cloudlinux"   CloudLinux OS
    "exherbo"      Exherbo Linux
    "gentoo"       Gentoo Linux
    "ibm_powerkvm" IBM PowerKVM
    "kvmibm"       KVM for IBM z Systems
    "linuxmint"    Linux Mint
    "mageia"       Mageia
    "mandriva"     Mandriva Linux
    "parallels"    Parallels
    "pidora"       Pidora
    "raspbian"     Raspbian
    "oracle"       Oracle Linux (and Oracle Enterprise Linux)
    "scientific"   Scientific Linux
    "slackware"    Slackware
    "xenserver"    XenServer
    "openbsd"      OpenBSD
    "netbsd"       NetBSD
    "freebsd"      FreeBSD
    "midnightbsd"  MidnightBSD
    ============== =========================================
If you have a need to get distros for reliable IDs added into this set,
or if you find that the :func:`distro.id` function returns a different
distro ID for one of the listed distros, please create an issue in the
`distro issue tracker`_.
**Lookup hierarchy and transformations:**
First, the ID is obtained from the following sources, in the specified
order. The first available and non-empty value is used:
* the value of the "ID" attribute of the os-release file,
* the value of the "Distributor ID" attribute returned by the lsb_release
command,
* the first part of the file name of the distro release file,
    The ID value determined this way then passes through the following transformations,
before it is returned by this method:
* it is translated to lower case,
* blanks (which should not be there anyway) are translated to underscores,
* a normalization of the ID is performed, based upon
`normalization tables`_. The purpose of this normalization is to ensure
that the ID is as reliable as possible, even across incompatible changes
in the OS distributions. A common reason for an incompatible change is
the addition of an os-release file, or the addition of the lsb_release
command, with ID values that differ from what was previously determined
from the distro release file name.
"""
return _distro.id()
def name(pretty=False):
# type: (bool) -> str
"""
Return the name of the current OS distribution, as a human-readable
string.
If *pretty* is false, the name is returned without version or codename.
(e.g. "CentOS Linux")
If *pretty* is true, the version and codename are appended.
(e.g. "CentOS Linux 7.1.1503 (Core)")
**Lookup hierarchy:**
The name is obtained from the following sources, in the specified order.
The first available and non-empty value is used:
* If *pretty* is false:
- the value of the "NAME" attribute of the os-release file,
- the value of the "Distributor ID" attribute returned by the lsb_release
command,
- the value of the "<name>" field of the distro release file.
* If *pretty* is true:
- the value of the "PRETTY_NAME" attribute of the os-release file,
- the value of the "Description" attribute returned by the lsb_release
command,
- the value of the "<name>" field of the distro release file, appended
with the value of the pretty version ("<version_id>" and "<codename>"
fields) of the distro release file, if available.
"""
return _distro.name(pretty)
def version(pretty=False, best=False):
# type: (bool, bool) -> str
"""
Return the version of the current OS distribution, as a human-readable
string.
If *pretty* is false, the version is returned without codename (e.g.
"7.0").
If *pretty* is true, the codename in parenthesis is appended, if the
codename is non-empty (e.g. "7.0 (Maipo)").
Some distributions provide version numbers with different precisions in
the different sources of distribution information. Examining the different
sources in a fixed priority order does not always yield the most precise
version (e.g. for Debian 8.2, or CentOS 7.1).
The *best* parameter can be used to control the approach for the returned
version:
If *best* is false, the first non-empty version number in priority order of
the examined sources is returned.
If *best* is true, the most precise version number out of all examined
sources is returned.
**Lookup hierarchy:**
In all cases, the version number is obtained from the following sources.
If *best* is false, this order represents the priority order:
* the value of the "VERSION_ID" attribute of the os-release file,
* the value of the "Release" attribute returned by the lsb_release
command,
* the version number parsed from the "<version_id>" field of the first line
of the distro release file,
* the version number parsed from the "PRETTY_NAME" attribute of the
os-release file, if it follows the format of the distro release files.
* the version number parsed from the "Description" attribute returned by
the lsb_release command, if it follows the format of the distro release
files.
"""
return _distro.version(pretty, best)
def version_parts(best=False):
# type: (bool) -> Tuple[str, str, str]
"""
Return the version of the current OS distribution as a tuple
``(major, minor, build_number)`` with items as follows:
* ``major``: The result of :func:`distro.major_version`.
* ``minor``: The result of :func:`distro.minor_version`.
* ``build_number``: The result of :func:`distro.build_number`.
For a description of the *best* parameter, see the :func:`distro.version`
method.
"""
return _distro.version_parts(best)
def major_version(best=False):
# type: (bool) -> str
"""
Return the major version of the current OS distribution, as a string,
if provided.
Otherwise, the empty string is returned. The major version is the first
part of the dot-separated version string.
For a description of the *best* parameter, see the :func:`distro.version`
method.
"""
return _distro.major_version(best)
def minor_version(best=False):
# type: (bool) -> str
"""
Return the minor version of the current OS distribution, as a string,
if provided.
Otherwise, the empty string is returned. The minor version is the second
part of the dot-separated version string.
For a description of the *best* parameter, see the :func:`distro.version`
method.
"""
return _distro.minor_version(best)
def build_number(best=False):
# type: (bool) -> str
"""
Return the build number of the current OS distribution, as a string,
if provided.
Otherwise, the empty string is returned. The build number is the third part
of the dot-separated version string.
For a description of the *best* parameter, see the :func:`distro.version`
method.
"""
return _distro.build_number(best)
def like():
# type: () -> str
"""
Return a space-separated list of distro IDs of distributions that are
closely related to the current OS distribution in regards to packaging
and programming interfaces, for example distributions the current
distribution is a derivative from.
**Lookup hierarchy:**
This information item is only provided by the os-release file.
For details, see the description of the "ID_LIKE" attribute in the
`os-release man page
<http://www.freedesktop.org/software/systemd/man/os-release.html>`_.
"""
return _distro.like()
def codename():
# type: () -> str
"""
Return the codename for the release of the current OS distribution,
as a string.
If the distribution does not have a codename, an empty string is returned.
Note that the returned codename is not always really a codename. For
example, openSUSE returns "x86_64". This function does not handle such
cases in any special way and just returns the string it finds, if any.
**Lookup hierarchy:**
* the codename within the "VERSION" attribute of the os-release file, if
provided,
* the value of the "Codename" attribute returned by the lsb_release
command,
* the value of the "<codename>" field of the distro release file.
"""
return _distro.codename()
def info(pretty=False, best=False):
# type: (bool, bool) -> InfoDict
"""
Return certain machine-readable information items about the current OS
distribution in a dictionary, as shown in the following example:
.. sourcecode:: python
{
'id': 'rhel',
'version': '7.0',
'version_parts': {
'major': '7',
'minor': '0',
'build_number': ''
},
'like': 'fedora',
'codename': 'Maipo'
}
The dictionary structure and keys are always the same, regardless of which
information items are available in the underlying data sources. The values
for the various keys are as follows:
* ``id``: The result of :func:`distro.id`.
* ``version``: The result of :func:`distro.version`.
* ``version_parts -> major``: The result of :func:`distro.major_version`.
* ``version_parts -> minor``: The result of :func:`distro.minor_version`.
* ``version_parts -> build_number``: The result of
:func:`distro.build_number`.
* ``like``: The result of :func:`distro.like`.
* ``codename``: The result of :func:`distro.codename`.
For a description of the *pretty* and *best* parameters, see the
:func:`distro.version` method.
"""
return _distro.info(pretty, best)
def os_release_info():
# type: () -> Dict[str, str]
"""
Return a dictionary containing key-value pairs for the information items
from the os-release file data source of the current OS distribution.
See `os-release file`_ for details about these information items.
"""
return _distro.os_release_info()
def lsb_release_info():
# type: () -> Dict[str, str]
"""
Return a dictionary containing key-value pairs for the information items
from the lsb_release command data source of the current OS distribution.
See `lsb_release command output`_ for details about these information
items.
"""
return _distro.lsb_release_info()
def distro_release_info():
# type: () -> Dict[str, str]
"""
Return a dictionary containing key-value pairs for the information items
from the distro release file data source of the current OS distribution.
See `distro release file`_ for details about these information items.
"""
return _distro.distro_release_info()
def uname_info():
# type: () -> Dict[str, str]
"""
    Return a dictionary containing key-value pairs for the information items
    from the uname command data source of the current OS distribution.
"""
return _distro.uname_info()
def os_release_attr(attribute):
# type: (str) -> str
"""
Return a single named information item from the os-release file data source
of the current OS distribution.
Parameters:
* ``attribute`` (string): Key of the information item.
Returns:
* (string): Value of the information item, if the item exists.
The empty string, if the item does not exist.
See `os-release file`_ for details about these information items.
"""
return _distro.os_release_attr(attribute)
def lsb_release_attr(attribute):
# type: (str) -> str
"""
Return a single named information item from the lsb_release command output
data source of the current OS distribution.
Parameters:
* ``attribute`` (string): Key of the information item.
Returns:
* (string): Value of the information item, if the item exists.
The empty string, if the item does not exist.
See `lsb_release command output`_ for details about these information
items.
"""
return _distro.lsb_release_attr(attribute)
def distro_release_attr(attribute):
# type: (str) -> str
"""
Return a single named information item from the distro release file
data source of the current OS distribution.
Parameters:
* ``attribute`` (string): Key of the information item.
Returns:
* (string): Value of the information item, if the item exists.
The empty string, if the item does not exist.
See `distro release file`_ for details about these information items.
"""
return _distro.distro_release_attr(attribute)
def uname_attr(attribute):
# type: (str) -> str
"""
    Return a single named information item from the uname command output
    data source of the current OS distribution.
Parameters:
* ``attribute`` (string): Key of the information item.
Returns:
* (string): Value of the information item, if the item exists.
The empty string, if the item does not exist.
"""
return _distro.uname_attr(attribute)
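# Usage sketch for the consolidated accessors (values shown for a hypothetical
# RHEL 7.0 host; actual output depends on the machine):
#
#     >>> import distro
#     >>> distro.id(), distro.version(), distro.codename()
#     ('rhel', '7.0', 'Maipo')
#     >>> distro.info()['version_parts']
#     {'major': '7', 'minor': '0', 'build_number': ''}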
try:
from functools import cached_property
except ImportError:
# Python < 3.8
class cached_property(object): # type: ignore
"""A version of @property which caches the value. On access, it calls the
underlying function and sets the value in `__dict__` so future accesses
will not re-call the property.
"""
def __init__(self, f):
# type: (Callable[[Any], Any]) -> None
self._fname = f.__name__
self._f = f
def __get__(self, obj, owner):
# type: (Any, Type[Any]) -> Any
assert obj is not None, "call {} on an instance".format(self._fname)
ret = obj.__dict__[self._fname] = self._f(obj)
return ret
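# Behaviour sketch for the cached_property shim above (hypothetical class):
# the first access computes the value and stores it in the instance
# ``__dict__``; since the descriptor defines no ``__set__``, later lookups hit
# the instance dict and never re-call the function.
#
#     >>> class Probe(object):
#     ...     @cached_property
#     ...     def value(self):
#     ...         print("computed")
#     ...         return 42
#     >>> p = Probe()
#     >>> p.value
#     computed
#     42
#     >>> p.value                  # cached; "computed" is not printed again
#     42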
class LinuxDistribution(object):
"""
    Provides information about an OS distribution.
This package creates a private module-global instance of this class with
default initialization arguments, that is used by the
`consolidated accessor functions`_ and `single source accessor functions`_.
By using default initialization arguments, that module-global instance
returns data about the current OS distribution (i.e. the distro this
package runs on).
Normally, it is not necessary to create additional instances of this class.
However, in situations where control is needed over the exact data sources
that are used, instances of this class can be created with a specific
distro release file, or a specific os-release file, or without invoking the
lsb_release command.
"""
def __init__(
self,
include_lsb=True,
os_release_file="",
distro_release_file="",
include_uname=True,
root_dir=None,
):
# type: (bool, str, str, bool, Optional[str]) -> None
"""
The initialization method of this class gathers information from the
available data sources, and stores that in private instance attributes.
Subsequent access to the information items uses these private instance
attributes, so that the data sources are read only once.
Parameters:
* ``include_lsb`` (bool): Controls whether the
`lsb_release command output`_ is included as a data source.
If the lsb_release command is not available in the program execution
path, the data source for the lsb_release command will be empty.
* ``os_release_file`` (string): The path name of the
`os-release file`_ that is to be used as a data source.
An empty string (the default) will cause the default path name to
be used (see `os-release file`_ for details).
If the specified or defaulted os-release file does not exist, the
data source for the os-release file will be empty.
* ``distro_release_file`` (string): The path name of the
`distro release file`_ that is to be used as a data source.
An empty string (the default) will cause a default search algorithm
to be used (see `distro release file`_ for details).
If the specified distro release file does not exist, or if no default
distro release file can be found, the data source for the distro
release file will be empty.
* ``include_uname`` (bool): Controls whether uname command output is
included as a data source. If the uname command is not available in
the program execution path the data source for the uname command will
be empty.
* ``root_dir`` (string): The absolute path to the root directory to use
to find distro-related information files.
Public instance attributes:
* ``os_release_file`` (string): The path name of the
`os-release file`_ that is actually used as a data source. The
empty string if no distro release file is used as a data source.
* ``distro_release_file`` (string): The path name of the
`distro release file`_ that is actually used as a data source. The
empty string if no distro release file is used as a data source.
* ``include_lsb`` (bool): The result of the ``include_lsb`` parameter.
This controls whether the lsb information will be loaded.
* ``include_uname`` (bool): The result of the ``include_uname``
parameter. This controls whether the uname information will
be loaded.
Raises:
* :py:exc:`IOError`: Some I/O issue with an os-release file or distro
release file.
* :py:exc:`subprocess.CalledProcessError`: The lsb_release command had
some issue (other than not being available in the program execution
path).
* :py:exc:`UnicodeError`: A data source has unexpected characters or
uses an unexpected encoding.
"""
self.root_dir = root_dir
self.etc_dir = os.path.join(root_dir, "etc") if root_dir else _UNIXCONFDIR
self.usr_lib_dir = (
os.path.join(root_dir, "usr/lib") if root_dir else _UNIXUSRLIBDIR
)
if os_release_file:
self.os_release_file = os_release_file
else:
etc_dir_os_release_file = os.path.join(self.etc_dir, _OS_RELEASE_BASENAME)
usr_lib_os_release_file = os.path.join(
self.usr_lib_dir, _OS_RELEASE_BASENAME
)
# NOTE: The idea is to respect order **and** have it set
# at all times for API backwards compatibility.
if os.path.isfile(etc_dir_os_release_file) or not os.path.isfile(
usr_lib_os_release_file
):
self.os_release_file = etc_dir_os_release_file
else:
self.os_release_file = usr_lib_os_release_file
self.distro_release_file = distro_release_file or "" # updated later
self.include_lsb = include_lsb
self.include_uname = include_uname
def __repr__(self):
# type: () -> str
"""Return repr of all info"""
return (
"LinuxDistribution("
"os_release_file={self.os_release_file!r}, "
"distro_release_file={self.distro_release_file!r}, "
"include_lsb={self.include_lsb!r}, "
"include_uname={self.include_uname!r}, "
"_os_release_info={self._os_release_info!r}, "
"_lsb_release_info={self._lsb_release_info!r}, "
"_distro_release_info={self._distro_release_info!r}, "
"_uname_info={self._uname_info!r})".format(self=self)
)
def linux_distribution(self, full_distribution_name=True):
# type: (bool) -> Tuple[str, str, str]
"""
Return information about the OS distribution that is compatible
with Python's :func:`platform.linux_distribution`, supporting a subset
of its parameters.
For details, see :func:`distro.linux_distribution`.
"""
return (
self.name() if full_distribution_name else self.id(),
self.version(),
self.codename(),
)
def id(self):
# type: () -> str
"""Return the distro ID of the OS distribution, as a string.
For details, see :func:`distro.id`.
"""
def normalize(distro_id, table):
# type: (str, Dict[str, str]) -> str
distro_id = distro_id.lower().replace(" ", "_")
return table.get(distro_id, distro_id)
distro_id = self.os_release_attr("id")
if distro_id:
return normalize(distro_id, NORMALIZED_OS_ID)
distro_id = self.lsb_release_attr("distributor_id")
if distro_id:
return normalize(distro_id, NORMALIZED_LSB_ID)
distro_id = self.distro_release_attr("id")
if distro_id:
return normalize(distro_id, NORMALIZED_DISTRO_ID)
distro_id = self.uname_attr("id")
if distro_id:
return normalize(distro_id, NORMALIZED_DISTRO_ID)
return ""
def name(self, pretty=False):
# type: (bool) -> str
"""
Return the name of the OS distribution, as a string.
For details, see :func:`distro.name`.
"""
name = (
self.os_release_attr("name")
or self.lsb_release_attr("distributor_id")
or self.distro_release_attr("name")
or self.uname_attr("name")
)
if pretty:
name = self.os_release_attr("pretty_name") or self.lsb_release_attr(
"description"
)
if not name:
name = self.distro_release_attr("name") or self.uname_attr("name")
version = self.version(pretty=True)
if version:
name = name + " " + version
return name or ""
def version(self, pretty=False, best=False):
# type: (bool, bool) -> str
"""
Return the version of the OS distribution, as a string.
For details, see :func:`distro.version`.
"""
versions = [
self.os_release_attr("version_id"),
self.lsb_release_attr("release"),
self.distro_release_attr("version_id"),
self._parse_distro_release_content(self.os_release_attr("pretty_name")).get(
"version_id", ""
),
self._parse_distro_release_content(
self.lsb_release_attr("description")
).get("version_id", ""),
self.uname_attr("release"),
]
version = ""
if best:
# This algorithm uses the last version in priority order that has
# the best precision. If the versions are not in conflict, that
# does not matter; otherwise, using the last one instead of the
# first one might be considered a surprise.
for v in versions:
if v.count(".") > version.count(".") or version == "":
version = v
else:
for v in versions:
if v != "":
version = v
break
if pretty and version and self.codename():
version = "{0} ({1})".format(version, self.codename())
return version
def version_parts(self, best=False):
# type: (bool) -> Tuple[str, str, str]
"""
Return the version of the OS distribution, as a tuple of version
numbers.
For details, see :func:`distro.version_parts`.
"""
version_str = self.version(best=best)
if version_str:
version_regex = re.compile(r"(\d+)\.?(\d+)?\.?(\d+)?")
matches = version_regex.match(version_str)
if matches:
major, minor, build_number = matches.groups()
return major, minor or "", build_number or ""
return "", "", ""
def major_version(self, best=False):
# type: (bool) -> str
"""
Return the major version number of the current distribution.
For details, see :func:`distro.major_version`.
"""
return self.version_parts(best)[0]
def minor_version(self, best=False):
# type: (bool) -> str
"""
Return the minor version number of the current distribution.
For details, see :func:`distro.minor_version`.
"""
return self.version_parts(best)[1]
def build_number(self, best=False):
# type: (bool) -> str
"""
Return the build number of the current distribution.
For details, see :func:`distro.build_number`.
"""
return self.version_parts(best)[2]
def like(self):
# type: () -> str
"""
Return the IDs of distributions that are like the OS distribution.
For details, see :func:`distro.like`.
"""
return self.os_release_attr("id_like") or ""
def codename(self):
# type: () -> str
"""
Return the codename of the OS distribution.
For details, see :func:`distro.codename`.
"""
try:
# Handle os_release specially since distros might purposefully set
# this to empty string to have no codename
return self._os_release_info["codename"]
except KeyError:
return (
self.lsb_release_attr("codename")
or self.distro_release_attr("codename")
or ""
)
def info(self, pretty=False, best=False):
# type: (bool, bool) -> InfoDict
"""
Return certain machine-readable information about the OS
distribution.
For details, see :func:`distro.info`.
"""
return dict(
id=self.id(),
version=self.version(pretty, best),
version_parts=dict(
major=self.major_version(best),
minor=self.minor_version(best),
build_number=self.build_number(best),
),
like=self.like(),
codename=self.codename(),
)
def os_release_info(self):
# type: () -> Dict[str, str]
"""
Return a dictionary containing key-value pairs for the information
items from the os-release file data source of the OS distribution.
For details, see :func:`distro.os_release_info`.
"""
return self._os_release_info
def lsb_release_info(self):
# type: () -> Dict[str, str]
"""
Return a dictionary containing key-value pairs for the information
items from the lsb_release command data source of the OS
distribution.
For details, see :func:`distro.lsb_release_info`.
"""
return self._lsb_release_info
def distro_release_info(self):
# type: () -> Dict[str, str]
"""
Return a dictionary containing key-value pairs for the information
items from the distro release file data source of the OS
distribution.
For details, see :func:`distro.distro_release_info`.
"""
return self._distro_release_info
def uname_info(self):
# type: () -> Dict[str, str]
"""
Return a dictionary containing key-value pairs for the information
items from the uname command data source of the OS distribution.
For details, see :func:`distro.uname_info`.
"""
return self._uname_info
def os_release_attr(self, attribute):
# type: (str) -> str
"""
Return a single named information item from the os-release file data
source of the OS distribution.
For details, see :func:`distro.os_release_attr`.
"""
return self._os_release_info.get(attribute, "")
def lsb_release_attr(self, attribute):
# type: (str) -> str
"""
Return a single named information item from the lsb_release command
output data source of the OS distribution.
For details, see :func:`distro.lsb_release_attr`.
"""
return self._lsb_release_info.get(attribute, "")
def distro_release_attr(self, attribute):
# type: (str) -> str
"""
Return a single named information item from the distro release file
data source of the OS distribution.
For details, see :func:`distro.distro_release_attr`.
"""
return self._distro_release_info.get(attribute, "")
def uname_attr(self, attribute):
# type: (str) -> str
"""
Return a single named information item from the uname command
output data source of the OS distribution.
For details, see :func:`distro.uname_attr`.
"""
return self._uname_info.get(attribute, "")
@cached_property
def _os_release_info(self):
# type: () -> Dict[str, str]
"""
Get the information items from the specified os-release file.
Returns:
A dictionary containing all information items.
"""
if os.path.isfile(self.os_release_file):
with open(self.os_release_file) as release_file:
return self._parse_os_release_content(release_file)
return {}
@staticmethod
def _parse_os_release_content(lines):
# type: (TextIO) -> Dict[str, str]
"""
Parse the lines of an os-release file.
Parameters:
* lines: Iterable through the lines in the os-release file.
Each line must be a unicode string or a UTF-8 encoded byte
string.
Returns:
A dictionary containing all information items.
"""
props = {}
lexer = shlex.shlex(lines, posix=True)
lexer.whitespace_split = True
# The shlex module defines its `wordchars` variable using literals,
# making it dependent on the encoding of the Python source file.
# In Python 2.6 and 2.7, the shlex source file is encoded in
# 'iso-8859-1', and the `wordchars` variable is defined as a byte
# string. This causes a UnicodeDecodeError to be raised when the
# parsed content is a unicode object. The following fix resolves that
# (... but it should be fixed in shlex...):
if sys.version_info[0] == 2 and isinstance(lexer.wordchars, bytes):
lexer.wordchars = lexer.wordchars.decode("iso-8859-1")
tokens = list(lexer)
for token in tokens:
# At this point, all shell-like parsing has been done (i.e.
# comments processed, quotes and backslash escape sequences
# processed, multi-line values assembled, trailing newlines
# stripped, etc.), so the tokens are now either:
# * variable assignments: var=value
# * commands or their arguments (not allowed in os-release)
if "=" in token:
k, v = token.split("=", 1)
props[k.lower()] = v
else:
# Ignore any tokens that are not variable assignments
pass
if "version_codename" in props:
            # os-release added a version_codename field. Use that in
            # preference to anything else. Note that some distros purposefully
            # do not have code names; they should be setting
            # version_codename="".
props["codename"] = props["version_codename"]
elif "ubuntu_codename" in props:
# Same as above but a non-standard field name used on older Ubuntus
props["codename"] = props["ubuntu_codename"]
elif "version" in props:
# If there is no version_codename, parse it from the version
match = re.search(r"(\(\D+\))|,(\s+)?\D+", props["version"])
if match:
codename = match.group()
codename = codename.strip("()")
codename = codename.strip(",")
codename = codename.strip()
                # The codename appears within parentheses.
props["codename"] = codename
return props
@cached_property
def _lsb_release_info(self):
# type: () -> Dict[str, str]
"""
Get the information items from the lsb_release command output.
Returns:
A dictionary containing all information items.
"""
if not self.include_lsb:
return {}
with open(os.devnull, "wb") as devnull:
try:
cmd = ("lsb_release", "-a")
stdout = subprocess.check_output(cmd, stderr=devnull)
# Command not found or lsb_release returned error
except (OSError, subprocess.CalledProcessError):
return {}
content = self._to_str(stdout).splitlines()
return self._parse_lsb_release_content(content)
@staticmethod
def _parse_lsb_release_content(lines):
# type: (Iterable[str]) -> Dict[str, str]
"""
Parse the output of the lsb_release command.
Parameters:
* lines: Iterable through the lines of the lsb_release output.
Each line must be a unicode string or a UTF-8 encoded byte
string.
Returns:
A dictionary containing all information items.
"""
props = {}
for line in lines:
kv = line.strip("\n").split(":", 1)
if len(kv) != 2:
# Ignore lines without colon.
continue
k, v = kv
props.update({k.replace(" ", "_").lower(): v.strip()})
return props
@cached_property
def _uname_info(self):
# type: () -> Dict[str, str]
with open(os.devnull, "wb") as devnull:
try:
cmd = ("uname", "-rs")
stdout = subprocess.check_output(cmd, stderr=devnull)
except OSError:
return {}
content = self._to_str(stdout).splitlines()
return self._parse_uname_content(content)
@staticmethod
def _parse_uname_content(lines):
# type: (Sequence[str]) -> Dict[str, str]
props = {}
match = re.search(r"^([^\s]+)\s+([\d\.]+)", lines[0].strip())
if match:
name, version = match.groups()
# This is to prevent the Linux kernel version from
# appearing as the 'best' version on otherwise
# identifiable distributions.
if name == "Linux":
return {}
props["id"] = name.lower()
props["name"] = name
props["release"] = version
return props
@staticmethod
def _to_str(text):
# type: (Union[bytes, str]) -> str
encoding = sys.getfilesystemencoding()
encoding = "utf-8" if encoding == "ascii" else encoding
if sys.version_info[0] >= 3:
if isinstance(text, bytes):
return text.decode(encoding)
else:
if isinstance(text, unicode): # noqa
return text.encode(encoding)
return text
@cached_property
def _distro_release_info(self):
# type: () -> Dict[str, str]
"""
Get the information items from the specified distro release file.
Returns:
A dictionary containing all information items.
"""
if self.distro_release_file:
# If it was specified, we use it and parse what we can, even if
# its file name or content does not match the expected pattern.
distro_info = self._parse_distro_release_file(self.distro_release_file)
basename = os.path.basename(self.distro_release_file)
# The file name pattern for user-specified distro release files
# is somewhat more tolerant (compared to when searching for the
# file), because we want to use what was specified as best as
# possible.
match = _DISTRO_RELEASE_BASENAME_PATTERN.match(basename)
if "name" in distro_info and "cloudlinux" in distro_info["name"].lower():
distro_info["id"] = "cloudlinux"
elif match:
distro_info["id"] = match.group(1)
return distro_info
else:
try:
basenames = os.listdir(self.etc_dir)
# We sort for repeatability in cases where there are multiple
# distro specific files; e.g. CentOS, Oracle, Enterprise all
# containing `redhat-release` on top of their own.
basenames.sort()
except OSError:
# This may occur when /etc is not readable but we can't be
# sure about the *-release files. Check common entries of
# /etc for information. If they turn out to not be there the
# error is handled in `_parse_distro_release_file()`.
basenames = [
"SuSE-release",
"arch-release",
"base-release",
"centos-release",
"fedora-release",
"gentoo-release",
"mageia-release",
"mandrake-release",
"mandriva-release",
"mandrivalinux-release",
"manjaro-release",
"oracle-release",
"redhat-release",
"sl-release",
"slackware-version",
]
for basename in basenames:
if basename in _DISTRO_RELEASE_IGNORE_BASENAMES:
continue
match = _DISTRO_RELEASE_BASENAME_PATTERN.match(basename)
if match:
filepath = os.path.join(self.etc_dir, basename)
distro_info = self._parse_distro_release_file(filepath)
if "name" in distro_info:
# The name is always present if the pattern matches
self.distro_release_file = filepath
distro_info["id"] = match.group(1)
if "cloudlinux" in distro_info["name"].lower():
distro_info["id"] = "cloudlinux"
return distro_info
return {}
def _parse_distro_release_file(self, filepath):
# type: (str) -> Dict[str, str]
"""
Parse a distro release file.
Parameters:
* filepath: Path name of the distro release file.
Returns:
A dictionary containing all information items.
"""
try:
with open(filepath) as fp:
# Only parse the first line. For instance, on SLES there
# are multiple lines. We don't want them...
return self._parse_distro_release_content(fp.readline())
except (OSError, IOError):
# Ignore not being able to read a specific, seemingly version
# related file.
# See https://github.com/python-distro/distro/issues/162
return {}
@staticmethod
def _parse_distro_release_content(line):
# type: (str) -> Dict[str, str]
"""
Parse a line from a distro release file.
Parameters:
* line: Line from the distro release file. Must be a unicode string
or a UTF-8 encoded byte string.
Returns:
A dictionary containing all information items.
"""
matches = _DISTRO_RELEASE_CONTENT_REVERSED_PATTERN.match(line.strip()[::-1])
distro_info = {}
if matches:
# regexp ensures non-None
distro_info["name"] = matches.group(3)[::-1]
if matches.group(2):
distro_info["version_id"] = matches.group(2)[::-1]
if matches.group(1):
distro_info["codename"] = matches.group(1)[::-1]
elif line:
distro_info["name"] = line.strip()
return distro_info
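# Illustrative sketch (not part of this module): the line is matched reversed so
# a single regex can peel the trailing "(codename)" and version off the end, no
# matter what the distro name itself contains.
# >>> LinuxDistribution._parse_distro_release_content(
# ...     "CentOS Linux release 7.1.1503 (Core)")
# {'name': 'CentOS Linux', 'version_id': '7.1.1503', 'codename': 'Core'}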
_distro = LinuxDistribution()
def main():
# type: () -> None
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))
parser = argparse.ArgumentParser(description="OS distro info tool")
parser.add_argument(
"--json", "-j", help="Output in machine readable format", action="store_true"
)
parser.add_argument(
"--root-dir",
"-r",
type=str,
dest="root_dir",
help="Path to the root filesystem directory (defaults to /)",
)
args = parser.parse_args()
if args.root_dir:
dist = LinuxDistribution(
include_lsb=False, include_uname=False, root_dir=args.root_dir
)
else:
dist = _distro
if args.json:
logger.info(json.dumps(dist.info(), indent=4, sort_keys=True))
else:
logger.info("Name: %s", dist.name(pretty=True))
distribution_version = dist.version(pretty=True)
logger.info("Version: %s", distribution_version)
distribution_codename = dist.codename()
logger.info("Codename: %s", distribution_codename)
if __name__ == "__main__":
main()
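# Example invocations of the CLI above (illustrative): `python distro.py` for a
# human-readable summary, `python distro.py --json` for machine-readable output,
# or `python distro.py --root-dir /mnt/sysroot` to inspect a mounted image.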
"""Pythonic command-line interface parser that will make you smile.
* http://docopt.org
* Repository and issue-tracker: https://github.com/docopt/docopt
* Licensed under terms of MIT license (see LICENSE-MIT)
* Copyright (c) 2013 Vladimir Keleshev, vladimir@keleshev.com
"""
import sys
import re
__all__ = ['docopt']
__version__ = '0.6.2'
class DocoptLanguageError(Exception):
"""Error in construction of usage-message by developer."""
class DocoptExit(SystemExit):
"""Exit in case user invoked program with incorrect arguments."""
usage = ''
def __init__(self, message=''):
SystemExit.__init__(self, (message + '\n' + self.usage).strip())
class Pattern(object):
def __eq__(self, other):
return repr(self) == repr(other)
def __hash__(self):
return hash(repr(self))
def fix(self):
self.fix_identities()
self.fix_repeating_arguments()
return self
def fix_identities(self, uniq=None):
"""Make pattern-tree tips point to same object if they are equal."""
if not hasattr(self, 'children'):
return self
uniq = list(set(self.flat())) if uniq is None else uniq
for i, c in enumerate(self.children):
if not hasattr(c, 'children'):
assert c in uniq
self.children[i] = uniq[uniq.index(c)]
else:
c.fix_identities(uniq)
def fix_repeating_arguments(self):
"""Fix elements that should accumulate/increment values."""
either = [list(c.children) for c in self.either.children]
for case in either:
for e in [c for c in case if case.count(c) > 1]:
if type(e) is Argument or type(e) is Option and e.argcount:
if e.value is None:
e.value = []
elif type(e.value) is not list:
e.value = e.value.split()
if type(e) is Command or type(e) is Option and e.argcount == 0:
e.value = 0
return self
@property
def either(self):
"""Transform pattern into an equivalent, with only top-level Either."""
# Currently the pattern will not be equivalent, but more "narrow",
# although good enough to reason about list arguments.
ret = []
groups = [[self]]
while groups:
children = groups.pop(0)
types = [type(c) for c in children]
if Either in types:
either = [c for c in children if type(c) is Either][0]
children.pop(children.index(either))
for c in either.children:
groups.append([c] + children)
elif Required in types:
required = [c for c in children if type(c) is Required][0]
children.pop(children.index(required))
groups.append(list(required.children) + children)
elif Optional in types:
optional = [c for c in children if type(c) is Optional][0]
children.pop(children.index(optional))
groups.append(list(optional.children) + children)
elif AnyOptions in types:
optional = [c for c in children if type(c) is AnyOptions][0]
children.pop(children.index(optional))
groups.append(list(optional.children) + children)
elif OneOrMore in types:
oneormore = [c for c in children if type(c) is OneOrMore][0]
children.pop(children.index(oneormore))
groups.append(list(oneormore.children) * 2 + children)
else:
ret.append(children)
return Either(*[Required(*e) for e in ret])
class ChildPattern(Pattern):
def __init__(self, name, value=None):
self.name = name
self.value = value
def __repr__(self):
return '%s(%r, %r)' % (self.__class__.__name__, self.name, self.value)
def flat(self, *types):
return [self] if not types or type(self) in types else []
def match(self, left, collected=None):
collected = [] if collected is None else collected
pos, match = self.single_match(left)
if match is None:
return False, left, collected
left_ = left[:pos] + left[pos + 1:]
same_name = [a for a in collected if a.name == self.name]
if type(self.value) in (int, list):
if type(self.value) is int:
increment = 1
else:
increment = ([match.value] if type(match.value) is str
else match.value)
if not same_name:
match.value = increment
return True, left_, collected + [match]
same_name[0].value += increment
return True, left_, collected
return True, left_, collected + [match]
class ParentPattern(Pattern):
def __init__(self, *children):
self.children = list(children)
def __repr__(self):
return '%s(%s)' % (self.__class__.__name__,
', '.join(repr(a) for a in self.children))
def flat(self, *types):
if type(self) in types:
return [self]
return sum([c.flat(*types) for c in self.children], [])
class Argument(ChildPattern):
def single_match(self, left):
for n, p in enumerate(left):
if type(p) is Argument:
return n, Argument(self.name, p.value)
return None, None
@classmethod
def parse(class_, source):
        name = re.findall(r'(<\S*?>)', source)[0]
        value = re.findall(r'\[default: (.*)\]', source, flags=re.I)
return class_(name, value[0] if value else None)
class Command(Argument):
def __init__(self, name, value=False):
self.name = name
self.value = value
def single_match(self, left):
for n, p in enumerate(left):
if type(p) is Argument:
if p.value == self.name:
return n, Command(self.name, True)
else:
break
return None, None
class Option(ChildPattern):
def __init__(self, short=None, long=None, argcount=0, value=False):
assert argcount in (0, 1)
self.short, self.long = short, long
self.argcount, self.value = argcount, value
self.value = None if value is False and argcount else value
@classmethod
def parse(class_, option_description):
short, long, argcount, value = None, None, 0, False
        options, _, description = option_description.strip().partition('  ')
options = options.replace(',', ' ').replace('=', ' ')
for s in options.split():
if s.startswith('--'):
long = s
elif s.startswith('-'):
short = s
else:
argcount = 1
if argcount:
            matched = re.findall(r'\[default: (.*)\]', description, flags=re.I)
value = matched[0] if matched else None
return class_(short, long, argcount, value)
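    # Illustrative sketch (not part of docopt): parsing options-section lines,
    # where two or more spaces separate the option list from its description.
    # >>> Option.parse('-h, --help  Show this screen.')
    # Option('-h', '--help', 0, False)
    # >>> Option.parse('--speed=<kn>  Speed in knots [default: 10].')
    # Option(None, '--speed', 1, '10')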
def single_match(self, left):
for n, p in enumerate(left):
if self.name == p.name:
return n, p
return None, None
@property
def name(self):
return self.long or self.short
def __repr__(self):
return 'Option(%r, %r, %r, %r)' % (self.short, self.long,
self.argcount, self.value)
class Required(ParentPattern):
def match(self, left, collected=None):
collected = [] if collected is None else collected
l = left
c = collected
for p in self.children:
matched, l, c = p.match(l, c)
if not matched:
return False, left, collected
return True, l, c
class Optional(ParentPattern):
def match(self, left, collected=None):
collected = [] if collected is None else collected
for p in self.children:
m, left, collected = p.match(left, collected)
return True, left, collected
class AnyOptions(Optional):
"""Marker/placeholder for [options] shortcut."""
class OneOrMore(ParentPattern):
def match(self, left, collected=None):
assert len(self.children) == 1
collected = [] if collected is None else collected
l = left
c = collected
l_ = None
matched = True
times = 0
while matched:
# could it be that something didn't match but changed l or c?
matched, l, c = self.children[0].match(l, c)
times += 1 if matched else 0
if l_ == l:
break
l_ = l
if times >= 1:
return True, l, c
return False, left, collected
class Either(ParentPattern):
def match(self, left, collected=None):
collected = [] if collected is None else collected
outcomes = []
for p in self.children:
matched, _, _ = outcome = p.match(left, collected)
if matched:
outcomes.append(outcome)
if outcomes:
return min(outcomes, key=lambda outcome: len(outcome[1]))
return False, left, collected
class TokenStream(list):
def __init__(self, source, error):
self += source.split() if hasattr(source, 'split') else source
self.error = error
def move(self):
return self.pop(0) if len(self) else None
def current(self):
return self[0] if len(self) else None
def parse_long(tokens, options):
"""long ::= '--' chars [ ( ' ' | '=' ) chars ] ;"""
long, eq, value = tokens.move().partition('=')
assert long.startswith('--')
value = None if eq == value == '' else value
similar = [o for o in options if o.long == long]
if tokens.error is DocoptExit and similar == []: # if no exact match
similar = [o for o in options if o.long and o.long.startswith(long)]
if len(similar) > 1: # might be simply specified ambiguously 2+ times?
raise tokens.error('%s is not a unique prefix: %s?' %
(long, ', '.join(o.long for o in similar)))
elif len(similar) < 1:
argcount = 1 if eq == '=' else 0
o = Option(None, long, argcount)
options.append(o)
if tokens.error is DocoptExit:
o = Option(None, long, argcount, value if argcount else True)
else:
o = Option(similar[0].short, similar[0].long,
similar[0].argcount, similar[0].value)
if o.argcount == 0:
if value is not None:
raise tokens.error('%s must not have an argument' % o.long)
else:
if value is None:
if tokens.current() is None:
raise tokens.error('%s requires argument' % o.long)
value = tokens.move()
if tokens.error is DocoptExit:
o.value = value if value is not None else True
return [o]
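# Illustrative sketch (not part of docopt): a known long option consuming its
# argument from the token stream during argv parsing.
# >>> parse_long(TokenStream('--timeout=30', DocoptExit),
# ...            [Option(None, '--timeout', 1)])
# [Option(None, '--timeout', 1, '30')]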
def parse_shorts(tokens, options):
"""shorts ::= '-' ( chars )* [ [ ' ' ] chars ] ;"""
token = tokens.move()
assert token.startswith('-') and not token.startswith('--')
left = token.lstrip('-')
parsed = []
while left != '':
short, left = '-' + left[0], left[1:]
similar = [o for o in options if o.short == short]
if len(similar) > 1:
raise tokens.error('%s is specified ambiguously %d times' %
(short, len(similar)))
elif len(similar) < 1:
o = Option(short, None, 0)
options.append(o)
if tokens.error is DocoptExit:
o = Option(short, None, 0, True)
        else:  # copy, so the entry in `options` keeps its default value
o = Option(short, similar[0].long,
similar[0].argcount, similar[0].value)
value = None
if o.argcount != 0:
if left == '':
if tokens.current() is None:
raise tokens.error('%s requires argument' % short)
value = tokens.move()
else:
value = left
left = ''
if tokens.error is DocoptExit:
o.value = value if value is not None else True
parsed.append(o)
return parsed
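# Illustrative sketch (not part of docopt): stacked short flags expand into one
# Option per letter.
# >>> parse_shorts(TokenStream('-ab', DocoptExit), [Option('-a'), Option('-b')])
# [Option('-a', None, 0, True), Option('-b', None, 0, True)]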
def parse_pattern(source, options):
tokens = TokenStream(re.sub(r'([\[\]\(\)\|]|\.\.\.)', r' \1 ', source),
DocoptLanguageError)
result = parse_expr(tokens, options)
if tokens.current() is not None:
raise tokens.error('unexpected ending: %r' % ' '.join(tokens))
return Required(*result)
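# Illustrative sketch (not part of docopt): usage-pattern source parsed into a
# pattern tree. DocoptLanguageError is the error class while parsing the usage
# text, so option values stay at their defaults here.
# >>> parse_pattern('[ -h ]', [Option('-h')])
# Required(Optional(Option('-h', None, 0, False)))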
def parse_expr(tokens, options):
"""expr ::= seq ( '|' seq )* ;"""
seq = parse_seq(tokens, options)
if tokens.current() != '|':
return seq
result = [Required(*seq)] if len(seq) > 1 else seq
while tokens.current() == '|':
tokens.move()
seq = parse_seq(tokens, options)
result += [Required(*seq)] if len(seq) > 1 else seq
return [Either(*result)] if len(result) > 1 else result
def parse_seq(tokens, options):
"""seq ::= ( atom [ '...' ] )* ;"""
result = []
while tokens.current() not in [None, ']', ')', '|']:
atom = parse_atom(tokens, options)
if tokens.current() == '...':
atom = [OneOrMore(*atom)]
tokens.move()
result += atom
return result
def parse_atom(tokens, options):
"""atom ::= '(' expr ')' | '[' expr ']' | 'options'
| long | shorts | argument | command ;
"""
token = tokens.current()
result = []
if token in '([':
tokens.move()
matching, pattern = {'(': [')', Required], '[': [']', Optional]}[token]
result = pattern(*parse_expr(tokens, options))
if tokens.move() != matching:
raise tokens.error("unmatched '%s'" % token)
return [result]
elif token == 'options':
tokens.move()
return [AnyOptions()]
elif token.startswith('--') and token != '--':
return parse_long(tokens, options)
elif token.startswith('-') and token not in ('-', '--'):
return parse_shorts(tokens, options)
elif token.startswith('<') and token.endswith('>') or token.isupper():
return [Argument(tokens.move())]
else:
return [Command(tokens.move())]
def parse_argv(tokens, options, options_first=False):
"""Parse command-line argument vector.
If options_first:
argv ::= [ long | shorts ]* [ argument ]* [ '--' [ argument ]* ] ;
else:
argv ::= [ long | shorts | argument ]* [ '--' [ argument ]* ] ;
"""
parsed = []
while tokens.current() is not None:
if tokens.current() == '--':
return parsed + [Argument(None, v) for v in tokens]
elif tokens.current().startswith('--'):
parsed += parse_long(tokens, options)
elif tokens.current().startswith('-') and tokens.current() != '-':
parsed += parse_shorts(tokens, options)
elif options_first:
return parsed + [Argument(None, v) for v in tokens]
else:
parsed.append(Argument(None, tokens.move()))
return parsed
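# Illustrative sketch (not part of docopt): argv tokens matched against known
# options; plain tokens become positional Arguments.
# >>> parse_argv(TokenStream('--verbose FILE', DocoptExit),
# ...            [Option(None, '--verbose')])
# [Option(None, '--verbose', 0, True), Argument(None, 'FILE')]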
def parse_defaults(doc):
# in python < 2.7 you can't pass flags=re.MULTILINE
    split = re.split(r'\n *(<\S+?>|-\S+?)', doc)[1:]
split = [s1 + s2 for s1, s2 in zip(split[::2], split[1::2])]
options = [Option.parse(s) for s in split if s.startswith('-')]
#arguments = [Argument.parse(s) for s in split if s.startswith('<')]
#return options, arguments
return options
def printable_usage(doc):
# in python < 2.7 you can't pass flags=re.IGNORECASE
usage_split = re.split(r'([Uu][Ss][Aa][Gg][Ee]:)', doc)
if len(usage_split) < 3:
raise DocoptLanguageError('"usage:" (case-insensitive) not found.')
if len(usage_split) > 3:
raise DocoptLanguageError('More than one "usage:" (case-insensitive).')
return re.split(r'\n\s*\n', ''.join(usage_split[1:]))[0].strip()
def formal_usage(printable_usage):
pu = printable_usage.split()[1:] # split and drop "usage:"
return '( ' + ' '.join(') | (' if s == pu[0] else s for s in pu[1:]) + ' )'
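# Illustrative sketch (not part of docopt): the usage section rewritten as a
# choice between required groups, one per usage line.
# >>> formal_usage("Usage: prog a b\n       prog c")
# '( a b ) | ( c )'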
def extras(help, version, options, doc):
if help and any((o.name in ('-h', '--help')) and o.value for o in options):
print(doc.strip("\n"))
sys.exit()
if version and any(o.name == '--version' and o.value for o in options):
print(version)
sys.exit()
class Dict(dict):
def __repr__(self):
return '{%s}' % ',\n '.join('%r: %r' % i for i in sorted(self.items()))
def docopt(doc, argv=None, help=True, version=None, options_first=False):
"""Parse `argv` based on command-line interface described in `doc`.
`docopt` creates your command-line interface based on its
description that you pass as `doc`. Such description can contain
--options, <positional-argument>, commands, which could be
[optional], (required), (mutually | exclusive) or repeated...
Parameters
----------
doc : str
Description of your command-line interface.
argv : list of str, optional
Argument vector to be parsed. sys.argv[1:] is used if not
provided.
help : bool (default: True)
Set to False to disable automatic help on -h or --help
options.
version : any object
If passed, the object will be printed if --version is in
`argv`.
options_first : bool (default: False)
        Set to True to require options precede positional arguments,
        i.e. to forbid mixing options and positional arguments.
Returns
-------
args : dict
A dictionary, where keys are names of command-line elements
such as e.g. "--verbose" and "<path>", and values are the
parsed values of those elements.
Example
-------
>>> from docopt import docopt
>>> doc = '''
Usage:
my_program tcp <host> <port> [--timeout=<seconds>]
my_program serial <port> [--baud=<n>] [--timeout=<seconds>]
my_program (-h | --help | --version)
Options:
-h, --help Show this screen and exit.
--baud=<n> Baudrate [default: 9600]
'''
>>> argv = ['tcp', '127.0.0.1', '80', '--timeout', '30']
>>> docopt(doc, argv)
{'--baud': '9600',
'--help': False,
'--timeout': '30',
'--version': False,
'<host>': '127.0.0.1',
'<port>': '80',
'serial': False,
'tcp': True}
See also
--------
* For video introduction see http://docopt.org
* Full documentation is available in README.rst as well as online
at https://github.com/docopt/docopt#readme
"""
if argv is None:
argv = sys.argv[1:]
DocoptExit.usage = printable_usage(doc)
options = parse_defaults(doc)
pattern = parse_pattern(formal_usage(DocoptExit.usage), options)
# [default] syntax for argument is disabled
#for a in pattern.flat(Argument):
# same_name = [d for d in arguments if d.name == a.name]
# if same_name:
# a.value = same_name[0].value
argv = parse_argv(TokenStream(argv, DocoptExit), list(options),
options_first)
pattern_options = set(pattern.flat(Option))
for ao in pattern.flat(AnyOptions):
doc_options = parse_defaults(doc)
ao.children = list(set(doc_options) - pattern_options)
#if any_options:
# ao.children += [Option(o.short, o.long, o.argcount)
# for o in argv if type(o) is Option]
extras(help, version, argv, doc)
matched, left, collected = pattern.fix().match(argv)
if matched and left == []: # better error message if left?
return Dict((a.name, a.value) for a in (pattern.flat() + collected))
raise DocoptExit()
"""Run the EasyInstall command"""
if __name__ == '__main__':
from setuptools.command.easy_install import main
main()
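# The single line that follows is the body of a setuptools namespace-package
# .pth shim for mpl_toolkits; .pth files execute one-line import statements,
# which is why it is written as one long semicolon-separated line.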
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('mpl_toolkits',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('mpl_toolkits', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('mpl_toolkits', [os.path.dirname(p)])));m = m or sys.modules.setdefault('mpl_toolkits', types.ModuleType('mpl_toolkits'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)
from matplotlib.pylab import *
import matplotlib.pylab
__doc__ = matplotlib.pylab.__doc__
# Copyright (c) 2010-2020 Benjamin Peterson
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Utilities for writing code that runs on Python 2 and 3"""
from __future__ import absolute_import
import functools
import itertools
import operator
import sys
import types
__author__ = "Benjamin Peterson <benjamin@python.org>"
__version__ = "1.16.0"
# Useful for very coarse version differentiation.
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
PY34 = sys.version_info[0:2] >= (3, 4)
if PY3:
string_types = str,
integer_types = int,
class_types = type,
text_type = str
binary_type = bytes
MAXSIZE = sys.maxsize
else:
string_types = basestring,
integer_types = (int, long)
class_types = (type, types.ClassType)
text_type = unicode
binary_type = str
if sys.platform.startswith("java"):
# Jython always uses 32 bits.
MAXSIZE = int((1 << 31) - 1)
else:
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
def __len__(self):
return 1 << 31
try:
len(X())
except OverflowError:
# 32-bit
MAXSIZE = int((1 << 31) - 1)
else:
# 64-bit
MAXSIZE = int((1 << 63) - 1)
del X
if PY34:
from importlib.util import spec_from_loader
else:
spec_from_loader = None
def _add_doc(func, doc):
"""Add documentation to a function."""
func.__doc__ = doc
def _import_module(name):
"""Import module, returning the module after the last dot."""
__import__(name)
return sys.modules[name]
class _LazyDescr(object):
def __init__(self, name):
self.name = name
def __get__(self, obj, tp):
result = self._resolve()
setattr(obj, self.name, result) # Invokes __set__.
try:
# This is a bit ugly, but it avoids running this again by
# removing this descriptor.
delattr(obj.__class__, self.name)
except AttributeError:
pass
return result
class MovedModule(_LazyDescr):
def __init__(self, name, old, new=None):
super(MovedModule, self).__init__(name)
if PY3:
if new is None:
new = name
self.mod = new
else:
self.mod = old
def _resolve(self):
return _import_module(self.mod)
def __getattr__(self, attr):
_module = self._resolve()
value = getattr(_module, attr)
setattr(self, attr, value)
return value
class _LazyModule(types.ModuleType):
def __init__(self, name):
super(_LazyModule, self).__init__(name)
self.__doc__ = self.__class__.__doc__
def __dir__(self):
attrs = ["__doc__", "__name__"]
attrs += [attr.name for attr in self._moved_attributes]
return attrs
# Subclasses should override this
_moved_attributes = []
class MovedAttribute(_LazyDescr):
def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
super(MovedAttribute, self).__init__(name)
if PY3:
if new_mod is None:
new_mod = name
self.mod = new_mod
if new_attr is None:
if old_attr is None:
new_attr = name
else:
new_attr = old_attr
self.attr = new_attr
else:
self.mod = old_mod
if old_attr is None:
old_attr = name
self.attr = old_attr
def _resolve(self):
module = _import_module(self.mod)
return getattr(module, self.attr)
class _SixMetaPathImporter(object):
"""
A meta path importer to import six.moves and its submodules.
This class implements a PEP302 finder and loader. It should be compatible
with Python 2.5 and all existing versions of Python3
"""
def __init__(self, six_module_name):
self.name = six_module_name
self.known_modules = {}
def _add_module(self, mod, *fullnames):
for fullname in fullnames:
self.known_modules[self.name + "." + fullname] = mod
def _get_module(self, fullname):
return self.known_modules[self.name + "." + fullname]
def find_module(self, fullname, path=None):
if fullname in self.known_modules:
return self
return None
def find_spec(self, fullname, path, target=None):
if fullname in self.known_modules:
return spec_from_loader(fullname, self)
return None
def __get_module(self, fullname):
try:
return self.known_modules[fullname]
except KeyError:
raise ImportError("This loader does not know module " + fullname)
def load_module(self, fullname):
try:
# in case of a reload
return sys.modules[fullname]
except KeyError:
pass
mod = self.__get_module(fullname)
if isinstance(mod, MovedModule):
mod = mod._resolve()
else:
mod.__loader__ = self
sys.modules[fullname] = mod
return mod
def is_package(self, fullname):
"""
        Return true if the named module is a package.
We need this method to get correct spec objects with
Python 3.4 (see PEP451)
"""
return hasattr(self.__get_module(fullname), "__path__")
def get_code(self, fullname):
"""Return None
Required, if is_package is implemented"""
self.__get_module(fullname) # eventually raises ImportError
return None
get_source = get_code # same as get_code
def create_module(self, spec):
return self.load_module(spec.name)
def exec_module(self, module):
pass
_importer = _SixMetaPathImporter(__name__)
class _MovedItems(_LazyModule):
"""Lazy loading of moved objects"""
__path__ = [] # mark as package
_moved_attributes = [
MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"),
MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
MovedAttribute("intern", "__builtin__", "sys"),
MovedAttribute("map", "itertools", "builtins", "imap", "map"),
MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"),
MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"),
MovedAttribute("getoutput", "commands", "subprocess"),
MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"),
MovedAttribute("reduce", "__builtin__", "functools"),
MovedAttribute("shlex_quote", "pipes", "shlex", "quote"),
MovedAttribute("StringIO", "StringIO", "io"),
MovedAttribute("UserDict", "UserDict", "collections"),
MovedAttribute("UserList", "UserList", "collections"),
MovedAttribute("UserString", "UserString", "collections"),
MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
MovedModule("builtins", "__builtin__"),
MovedModule("configparser", "ConfigParser"),
MovedModule("collections_abc", "collections", "collections.abc" if sys.version_info >= (3, 3) else "collections"),
MovedModule("copyreg", "copy_reg"),
MovedModule("dbm_gnu", "gdbm", "dbm.gnu"),
MovedModule("dbm_ndbm", "dbm", "dbm.ndbm"),
MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread" if sys.version_info < (3, 9) else "_thread"),
MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
MovedModule("http_cookies", "Cookie", "http.cookies"),
MovedModule("html_entities", "htmlentitydefs", "html.entities"),
MovedModule("html_parser", "HTMLParser", "html.parser"),
MovedModule("http_client", "httplib", "http.client"),
MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"),
MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"),
MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"),
MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
MovedModule("cPickle", "cPickle", "pickle"),
MovedModule("queue", "Queue"),
MovedModule("reprlib", "repr"),
MovedModule("socketserver", "SocketServer"),
MovedModule("_thread", "thread", "_thread"),
MovedModule("tkinter", "Tkinter"),
MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"),
MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
MovedModule("tkinter_colorchooser", "tkColorChooser",
"tkinter.colorchooser"),
MovedModule("tkinter_commondialog", "tkCommonDialog",
"tkinter.commondialog"),
MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
MovedModule("tkinter_font", "tkFont", "tkinter.font"),
MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
MovedModule("tkinter_tksimpledialog", "tkSimpleDialog",
"tkinter.simpledialog"),
MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"),
MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"),
MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"),
MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"),
MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"),
]
# Add windows specific modules.
if sys.platform == "win32":
_moved_attributes += [
MovedModule("winreg", "_winreg"),
]
for attr in _moved_attributes:
setattr(_MovedItems, attr.name, attr)
if isinstance(attr, MovedModule):
_importer._add_module(attr, "moves." + attr.name)
del attr
_MovedItems._moved_attributes = _moved_attributes
moves = _MovedItems(__name__ + ".moves")
_importer._add_module(moves, "moves")
class Module_six_moves_urllib_parse(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_parse"""
_urllib_parse_moved_attributes = [
MovedAttribute("ParseResult", "urlparse", "urllib.parse"),
MovedAttribute("SplitResult", "urlparse", "urllib.parse"),
MovedAttribute("parse_qs", "urlparse", "urllib.parse"),
MovedAttribute("parse_qsl", "urlparse", "urllib.parse"),
MovedAttribute("urldefrag", "urlparse", "urllib.parse"),
MovedAttribute("urljoin", "urlparse", "urllib.parse"),
MovedAttribute("urlparse", "urlparse", "urllib.parse"),
MovedAttribute("urlsplit", "urlparse", "urllib.parse"),
MovedAttribute("urlunparse", "urlparse", "urllib.parse"),
MovedAttribute("urlunsplit", "urlparse", "urllib.parse"),
MovedAttribute("quote", "urllib", "urllib.parse"),
MovedAttribute("quote_plus", "urllib", "urllib.parse"),
MovedAttribute("unquote", "urllib", "urllib.parse"),
MovedAttribute("unquote_plus", "urllib", "urllib.parse"),
MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"),
MovedAttribute("urlencode", "urllib", "urllib.parse"),
MovedAttribute("splitquery", "urllib", "urllib.parse"),
MovedAttribute("splittag", "urllib", "urllib.parse"),
MovedAttribute("splituser", "urllib", "urllib.parse"),
MovedAttribute("splitvalue", "urllib", "urllib.parse"),
MovedAttribute("uses_fragment", "urlparse", "urllib.parse"),
MovedAttribute("uses_netloc", "urlparse", "urllib.parse"),
MovedAttribute("uses_params", "urlparse", "urllib.parse"),
MovedAttribute("uses_query", "urlparse", "urllib.parse"),
MovedAttribute("uses_relative", "urlparse", "urllib.parse"),
]
for attr in _urllib_parse_moved_attributes:
setattr(Module_six_moves_urllib_parse, attr.name, attr)
del attr
Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes
_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"),
"moves.urllib_parse", "moves.urllib.parse")
class Module_six_moves_urllib_error(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_error"""
_urllib_error_moved_attributes = [
MovedAttribute("URLError", "urllib2", "urllib.error"),
MovedAttribute("HTTPError", "urllib2", "urllib.error"),
MovedAttribute("ContentTooShortError", "urllib", "urllib.error"),
]
for attr in _urllib_error_moved_attributes:
setattr(Module_six_moves_urllib_error, attr.name, attr)
del attr
Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes
_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"),
"moves.urllib_error", "moves.urllib.error")
class Module_six_moves_urllib_request(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_request"""
_urllib_request_moved_attributes = [
MovedAttribute("urlopen", "urllib2", "urllib.request"),
MovedAttribute("install_opener", "urllib2", "urllib.request"),
MovedAttribute("build_opener", "urllib2", "urllib.request"),
MovedAttribute("pathname2url", "urllib", "urllib.request"),
MovedAttribute("url2pathname", "urllib", "urllib.request"),
MovedAttribute("getproxies", "urllib", "urllib.request"),
MovedAttribute("Request", "urllib2", "urllib.request"),
MovedAttribute("OpenerDirector", "urllib2", "urllib.request"),
MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"),
MovedAttribute("ProxyHandler", "urllib2", "urllib.request"),
MovedAttribute("BaseHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"),
MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"),
MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"),
MovedAttribute("FileHandler", "urllib2", "urllib.request"),
MovedAttribute("FTPHandler", "urllib2", "urllib.request"),
MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"),
MovedAttribute("UnknownHandler", "urllib2", "urllib.request"),
MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"),
MovedAttribute("urlretrieve", "urllib", "urllib.request"),
MovedAttribute("urlcleanup", "urllib", "urllib.request"),
MovedAttribute("URLopener", "urllib", "urllib.request"),
MovedAttribute("FancyURLopener", "urllib", "urllib.request"),
MovedAttribute("proxy_bypass", "urllib", "urllib.request"),
MovedAttribute("parse_http_list", "urllib2", "urllib.request"),
MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"),
]
for attr in _urllib_request_moved_attributes:
setattr(Module_six_moves_urllib_request, attr.name, attr)
del attr
Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes
_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"),
"moves.urllib_request", "moves.urllib.request")
class Module_six_moves_urllib_response(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_response"""
_urllib_response_moved_attributes = [
MovedAttribute("addbase", "urllib", "urllib.response"),
MovedAttribute("addclosehook", "urllib", "urllib.response"),
MovedAttribute("addinfo", "urllib", "urllib.response"),
MovedAttribute("addinfourl", "urllib", "urllib.response"),
]
for attr in _urllib_response_moved_attributes:
setattr(Module_six_moves_urllib_response, attr.name, attr)
del attr
Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes
_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"),
"moves.urllib_response", "moves.urllib.response")
class Module_six_moves_urllib_robotparser(_LazyModule):
"""Lazy loading of moved objects in six.moves.urllib_robotparser"""
_urllib_robotparser_moved_attributes = [
MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"),
]
for attr in _urllib_robotparser_moved_attributes:
setattr(Module_six_moves_urllib_robotparser, attr.name, attr)
del attr
Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes
_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"),
"moves.urllib_robotparser", "moves.urllib.robotparser")
class Module_six_moves_urllib(types.ModuleType):
"""Create a six.moves.urllib namespace that resembles the Python 3 namespace"""
__path__ = [] # mark as package
parse = _importer._get_module("moves.urllib_parse")
error = _importer._get_module("moves.urllib_error")
request = _importer._get_module("moves.urllib_request")
response = _importer._get_module("moves.urllib_response")
robotparser = _importer._get_module("moves.urllib_robotparser")
def __dir__(self):
return ['parse', 'error', 'request', 'response', 'robotparser']
_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"),
"moves.urllib")
def add_move(move):
"""Add an item to six.moves."""
setattr(_MovedItems, move.name, move)
def remove_move(name):
"""Remove item from six.moves."""
try:
delattr(_MovedItems, name)
except AttributeError:
try:
del moves.__dict__[name]
except KeyError:
raise AttributeError("no such move, %r" % (name,))
if PY3:
_meth_func = "__func__"
_meth_self = "__self__"
_func_closure = "__closure__"
_func_code = "__code__"
_func_defaults = "__defaults__"
_func_globals = "__globals__"
else:
_meth_func = "im_func"
_meth_self = "im_self"
_func_closure = "func_closure"
_func_code = "func_code"
_func_defaults = "func_defaults"
_func_globals = "func_globals"
try:
advance_iterator = next
except NameError:
def advance_iterator(it):
return it.next()
next = advance_iterator
try:
callable = callable
except NameError:
def callable(obj):
return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
if PY3:
def get_unbound_function(unbound):
return unbound
create_bound_method = types.MethodType
def create_unbound_method(func, cls):
return func
Iterator = object
else:
def get_unbound_function(unbound):
return unbound.im_func
def create_bound_method(func, obj):
return types.MethodType(func, obj, obj.__class__)
def create_unbound_method(func, cls):
return types.MethodType(func, None, cls)
class Iterator(object):
def next(self):
return type(self).__next__(self)
callable = callable
_add_doc(get_unbound_function,
"""Get the function out of a possibly unbound function""")
get_method_function = operator.attrgetter(_meth_func)
get_method_self = operator.attrgetter(_meth_self)
get_function_closure = operator.attrgetter(_func_closure)
get_function_code = operator.attrgetter(_func_code)
get_function_defaults = operator.attrgetter(_func_defaults)
get_function_globals = operator.attrgetter(_func_globals)
if PY3:
def iterkeys(d, **kw):
return iter(d.keys(**kw))
def itervalues(d, **kw):
return iter(d.values(**kw))
def iteritems(d, **kw):
return iter(d.items(**kw))
def iterlists(d, **kw):
return iter(d.lists(**kw))
viewkeys = operator.methodcaller("keys")
viewvalues = operator.methodcaller("values")
viewitems = operator.methodcaller("items")
else:
def iterkeys(d, **kw):
return d.iterkeys(**kw)
def itervalues(d, **kw):
return d.itervalues(**kw)
def iteritems(d, **kw):
return d.iteritems(**kw)
def iterlists(d, **kw):
return d.iterlists(**kw)
viewkeys = operator.methodcaller("viewkeys")
viewvalues = operator.methodcaller("viewvalues")
viewitems = operator.methodcaller("viewitems")
_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.")
_add_doc(itervalues, "Return an iterator over the values of a dictionary.")
_add_doc(iteritems,
"Return an iterator over the (key, value) pairs of a dictionary.")
_add_doc(iterlists,
"Return an iterator over the (key, [values]) pairs of a dictionary.")
if PY3:
def b(s):
return s.encode("latin-1")
def u(s):
return s
unichr = chr
import struct
int2byte = struct.Struct(">B").pack
del struct
byte2int = operator.itemgetter(0)
indexbytes = operator.getitem
iterbytes = iter
import io
StringIO = io.StringIO
BytesIO = io.BytesIO
del io
_assertCountEqual = "assertCountEqual"
if sys.version_info[1] <= 1:
_assertRaisesRegex = "assertRaisesRegexp"
_assertRegex = "assertRegexpMatches"
_assertNotRegex = "assertNotRegexpMatches"
else:
_assertRaisesRegex = "assertRaisesRegex"
_assertRegex = "assertRegex"
_assertNotRegex = "assertNotRegex"
else:
def b(s):
return s
# Workaround for standalone backslash
def u(s):
return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape")
unichr = unichr
int2byte = chr
def byte2int(bs):
return ord(bs[0])
def indexbytes(buf, i):
return ord(buf[i])
iterbytes = functools.partial(itertools.imap, ord)
import StringIO
StringIO = BytesIO = StringIO.StringIO
_assertCountEqual = "assertItemsEqual"
_assertRaisesRegex = "assertRaisesRegexp"
_assertRegex = "assertRegexpMatches"
_assertNotRegex = "assertNotRegexpMatches"
_add_doc(b, """Byte literal""")
_add_doc(u, """Text literal""")
def assertCountEqual(self, *args, **kwargs):
return getattr(self, _assertCountEqual)(*args, **kwargs)
def assertRaisesRegex(self, *args, **kwargs):
return getattr(self, _assertRaisesRegex)(*args, **kwargs)
def assertRegex(self, *args, **kwargs):
return getattr(self, _assertRegex)(*args, **kwargs)
def assertNotRegex(self, *args, **kwargs):
return getattr(self, _assertNotRegex)(*args, **kwargs)
if PY3:
exec_ = getattr(moves.builtins, "exec")
def reraise(tp, value, tb=None):
try:
if value is None:
value = tp()
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
finally:
value = None
tb = None
else:
def exec_(_code_, _globs_=None, _locs_=None):
"""Execute code in a namespace."""
if _globs_ is None:
frame = sys._getframe(1)
_globs_ = frame.f_globals
if _locs_ is None:
_locs_ = frame.f_locals
del frame
elif _locs_ is None:
_locs_ = _globs_
exec("""exec _code_ in _globs_, _locs_""")
exec_("""def reraise(tp, value, tb=None):
try:
raise tp, value, tb
finally:
tb = None
""")
if sys.version_info[:2] > (3,):
exec_("""def raise_from(value, from_value):
try:
raise value from from_value
finally:
value = None
""")
else:
def raise_from(value, from_value):
raise value
print_ = getattr(moves.builtins, "print", None)
if print_ is None:
def print_(*args, **kwargs):
"""The new-style print function for Python 2.4 and 2.5."""
fp = kwargs.pop("file", sys.stdout)
if fp is None:
return
def write(data):
if not isinstance(data, basestring):
data = str(data)
# If the file has an encoding, encode unicode with it.
if (isinstance(fp, file) and
isinstance(data, unicode) and
fp.encoding is not None):
errors = getattr(fp, "errors", None)
if errors is None:
errors = "strict"
data = data.encode(fp.encoding, errors)
fp.write(data)
want_unicode = False
sep = kwargs.pop("sep", None)
if sep is not None:
if isinstance(sep, unicode):
want_unicode = True
elif not isinstance(sep, str):
raise TypeError("sep must be None or a string")
end = kwargs.pop("end", None)
if end is not None:
if isinstance(end, unicode):
want_unicode = True
elif not isinstance(end, str):
raise TypeError("end must be None or a string")
if kwargs:
raise TypeError("invalid keyword arguments to print()")
if not want_unicode:
for arg in args:
if isinstance(arg, unicode):
want_unicode = True
break
if want_unicode:
newline = unicode("\n")
space = unicode(" ")
else:
newline = "\n"
space = " "
if sep is None:
sep = space
if end is None:
end = newline
for i, arg in enumerate(args):
if i:
write(sep)
write(arg)
write(end)
if sys.version_info[:2] < (3, 3):
_print = print_
def print_(*args, **kwargs):
fp = kwargs.get("file", sys.stdout)
flush = kwargs.pop("flush", False)
_print(*args, **kwargs)
if flush and fp is not None:
fp.flush()
_add_doc(reraise, """Reraise an exception.""")
if sys.version_info[0:2] < (3, 4):
    # This does exactly what the :func:`py3:functools.update_wrapper`
    # function does on Python versions after 3.2. It sets the ``__wrapped__``
    # attribute on the ``wrapper`` object and it doesn't raise an error if any
    # of the attributes mentioned in ``assigned`` and ``updated`` are missing
    # on the ``wrapped`` object.
def _update_wrapper(wrapper, wrapped,
assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES):
for attr in assigned:
try:
value = getattr(wrapped, attr)
except AttributeError:
continue
else:
setattr(wrapper, attr, value)
for attr in updated:
getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
wrapper.__wrapped__ = wrapped
return wrapper
_update_wrapper.__doc__ = functools.update_wrapper.__doc__
def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES):
return functools.partial(_update_wrapper, wrapped=wrapped,
assigned=assigned, updated=updated)
wraps.__doc__ = functools.wraps.__doc__
else:
wraps = functools.wraps
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(type):
def __new__(cls, name, this_bases, d):
if sys.version_info[:2] >= (3, 7):
# This version introduced PEP 560 that requires a bit
# of extra care (we mimic what is done by __build_class__).
resolved_bases = types.resolve_bases(bases)
if resolved_bases is not bases:
d['__orig_bases__'] = bases
else:
resolved_bases = bases
return meta(name, resolved_bases, d)
@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, 'temporary_class', (), {})
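# Illustrative sketch (not part of six): a class that gets `Meta` as its
# metaclass on both Python 2 and 3 (`Meta` and `Base` are placeholder names).
# >>> class Meta(type): pass
# >>> class Base(object): pass
# >>> class MyClass(with_metaclass(Meta, Base)): pass
# >>> type(MyClass) is Meta
# True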
def add_metaclass(metaclass):
"""Class decorator for creating a class with a metaclass."""
def wrapper(cls):
orig_vars = cls.__dict__.copy()
slots = orig_vars.get('__slots__')
if slots is not None:
if isinstance(slots, str):
slots = [slots]
for slots_var in slots:
orig_vars.pop(slots_var)
orig_vars.pop('__dict__', None)
orig_vars.pop('__weakref__', None)
if hasattr(cls, '__qualname__'):
orig_vars['__qualname__'] = cls.__qualname__
return metaclass(cls.__name__, cls.__bases__, orig_vars)
return wrapper
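# Illustrative sketch (not part of six): the decorator form of the same idea,
# reusing the placeholder `Meta` from the sketch above.
# >>> @add_metaclass(Meta)
# ... class Decorated(object): pass
# >>> type(Decorated) is Meta
# True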
def ensure_binary(s, encoding='utf-8', errors='strict'):
"""Coerce **s** to six.binary_type.
For Python 2:
- `unicode` -> encoded to `str`
- `str` -> `str`
For Python 3:
- `str` -> encoded to `bytes`
- `bytes` -> `bytes`
"""
if isinstance(s, binary_type):
return s
if isinstance(s, text_type):
return s.encode(encoding, errors)
raise TypeError("not expecting type '%s'" % type(s))
def ensure_str(s, encoding='utf-8', errors='strict'):
"""Coerce *s* to `str`.
For Python 2:
- `unicode` -> encoded to `str`
- `str` -> `str`
For Python 3:
- `str` -> `str`
- `bytes` -> decoded to `str`
"""
# Optimization: Fast return for the common case.
if type(s) is str:
return s
if PY2 and isinstance(s, text_type):
return s.encode(encoding, errors)
elif PY3 and isinstance(s, binary_type):
return s.decode(encoding, errors)
elif not isinstance(s, (text_type, binary_type)):
raise TypeError("not expecting type '%s'" % type(s))
return s
def ensure_text(s, encoding='utf-8', errors='strict'):
"""Coerce *s* to six.text_type.
For Python 2:
- `unicode` -> `unicode`
- `str` -> `unicode`
For Python 3:
- `str` -> `str`
- `bytes` -> decoded to `str`
"""
if isinstance(s, binary_type):
return s.decode(encoding, errors)
elif isinstance(s, text_type):
return s
else:
raise TypeError("not expecting type '%s'" % type(s))
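# Illustrative sketch (not part of six): on Python 3 the three coercers behave
# as follows.
# >>> ensure_binary('abc')
# b'abc'
# >>> ensure_str(b'abc')
# 'abc'
# >>> ensure_text(b'abc')
# 'abc'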
def python_2_unicode_compatible(klass):
"""
A class decorator that defines __unicode__ and __str__ methods under Python 2.
Under Python 3 it does nothing.
To support Python 2 and 3 with a single code base, define a __str__ method
returning text and apply this decorator to the class.
"""
if PY2:
if '__str__' not in klass.__dict__:
raise ValueError("@python_2_unicode_compatible cannot be applied "
"to %s because it doesn't define __str__()." %
klass.__name__)
klass.__unicode__ = klass.__str__
klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
return klass
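# Illustrative sketch (not part of six): define a text-returning __str__ once;
# on Python 2 the decorator turns it into __unicode__ plus a UTF-8 __str__.
# >>> @python_2_unicode_compatible
# ... class Greeting(object):
# ...     def __str__(self):
# ...         return u'hello'
# >>> str(Greeting())
# 'hello'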
# Complete the moves implementation.
# This code is at the end of this module to speed up module loading.
# Turn this module into a package.
__path__ = [] # required for PEP 302 and PEP 451
__package__ = __name__ # see PEP 366 @ReservedAssignment
if globals().get("__spec__") is not None:
__spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable
# Remove other six meta path importers, since they cause problems. This can
# happen if six is removed from sys.modules and then reloaded. (Setuptools does
# this for some reason.)
if sys.meta_path:
for i, importer in enumerate(sys.meta_path):
# Here's some real nastiness: Another "instance" of the six module might
# be floating around. Therefore, we can't use isinstance() to check for
# the six meta path importer, since the other six instance will have
# inserted an importer with different class.
if (type(importer).__name__ == "_SixMetaPathImporter" and
importer.name == __name__):
del sys.meta_path[i]
break
del i, importer
# Finally, add the importer to the meta path import hook.
sys.meta_path.append(_importer)
# -*- coding: utf-8 -*-
"""Pretty-print tabular data."""
from __future__ import print_function
from __future__ import unicode_literals
from collections import namedtuple
import sys
import re
import math
if sys.version_info >= (3, 3):
from collections.abc import Iterable
else:
from collections import Iterable
if sys.version_info[0] < 3:
from itertools import izip_longest
from functools import partial
_none_type = type(None)
_bool_type = bool
_int_type = int
_long_type = long # noqa
_float_type = float
_text_type = unicode # noqa
_binary_type = str
def _is_file(f):
return hasattr(f, "read")
else:
from itertools import zip_longest as izip_longest
from functools import reduce, partial
_none_type = type(None)
_bool_type = bool
_int_type = int
_long_type = int
_float_type = float
_text_type = str
_binary_type = bytes
basestring = str
import io
def _is_file(f):
return isinstance(f, io.IOBase)
try:
import wcwidth # optional wide-character (CJK) support
except ImportError:
wcwidth = None
try:
from html import escape as htmlescape
except ImportError:
from cgi import escape as htmlescape
__all__ = ["tabulate", "tabulate_formats", "simple_separated_format"]
__version__ = "0.8.9"
# minimum extra space in headers
MIN_PADDING = 2
# Whether or not to preserve leading/trailing whitespace in data.
PRESERVE_WHITESPACE = False
_DEFAULT_FLOATFMT = "g"
_DEFAULT_MISSINGVAL = ""
# default align will be overwritten by "left", "center" or "decimal"
# depending on the formatter
_DEFAULT_ALIGN = "default"
# if True, enable wide-character (CJK) support
WIDE_CHARS_MODE = wcwidth is not None
Line = namedtuple("Line", ["begin", "hline", "sep", "end"])
DataRow = namedtuple("DataRow", ["begin", "sep", "end"])
# A table structure is supposed to be:
#
# --- lineabove ---------
# headerrow
# --- linebelowheader ---
# datarow
# --- linebetweenrows ---
# ... (more datarows) ...
# --- linebetweenrows ---
# last datarow
# --- linebelow ---------
#
# TableFormat's line* elements can be
#
# - either None, if the element is not used,
# - or a Line tuple,
# - or a function: [col_widths], [col_alignments] -> string.
#
# TableFormat's *row elements can be
#
# - either None, if the element is not used,
# - or a DataRow tuple,
# - or a function: [cell_values], [col_widths], [col_alignments] -> string.
#
# padding (an integer) is the amount of white space around data values.
#
# with_header_hide:
#
# - either None, to display all table elements unconditionally,
# - or a list of elements not to be displayed if the table has column headers.
#
TableFormat = namedtuple(
"TableFormat",
[
"lineabove",
"linebelowheader",
"linebetweenrows",
"linebelow",
"headerrow",
"datarow",
"padding",
"with_header_hide",
],
)
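# Illustrative sketch (not part of tabulate): a minimal custom format built
# from the primitives above, with colon-separated cells and no horizontal
# rules. Kept as a comment because `tabulate` itself is defined further below.
# custom = TableFormat(lineabove=None, linebelowheader=None,
#                      linebetweenrows=None, linebelow=None,
#                      headerrow=DataRow("", ":", ""),
#                      datarow=DataRow("", ":", ""),
#                      padding=0, with_header_hide=None)
# tabulate([["a", 1]], headers=["x", "y"], tablefmt=custom)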
def _pipe_segment_with_colons(align, colwidth):
"""Return a segment of a horizontal line with optional colons which
indicate column's alignment (as in `pipe` output format)."""
w = colwidth
if align in ["right", "decimal"]:
return ("-" * (w - 1)) + ":"
elif align == "center":
return ":" + ("-" * (w - 2)) + ":"
elif align == "left":
return ":" + ("-" * (w - 1))
else:
return "-" * w
def _pipe_line_with_colons(colwidths, colaligns):
"""Return a horizontal line with optional colons to indicate column's
alignment (as in `pipe` output format)."""
if not colaligns: # e.g. printing an empty data frame (github issue #15)
colaligns = [""] * len(colwidths)
segments = [_pipe_segment_with_colons(a, w) for a, w in zip(colaligns, colwidths)]
return "|" + "|".join(segments) + "|"
def _mediawiki_row_with_attrs(separator, cell_values, colwidths, colaligns):
alignment = {
"left": "",
"right": 'align="right"| ',
"center": 'align="center"| ',
"decimal": 'align="right"| ',
}
# hard-coded padding _around_ align attribute and value together
# rather than padding parameter which affects only the value
values_with_attrs = [
" " + alignment.get(a, "") + c + " " for c, a in zip(cell_values, colaligns)
]
colsep = separator * 2
return (separator + colsep.join(values_with_attrs)).rstrip()
def _textile_row_with_attrs(cell_values, colwidths, colaligns):
cell_values[0] += " "
alignment = {"left": "<.", "right": ">.", "center": "=.", "decimal": ">."}
values = (alignment.get(a, "") + v for a, v in zip(colaligns, cell_values))
return "|" + "|".join(values) + "|"
def _html_begin_table_without_header(colwidths_ignore, colaligns_ignore):
    # this <table>/<tbody> opening is emitted only when there is no header row;
    # a header row emits its own <table><thead> prologue instead
return "<table>\n<tbody>"
def _html_row_with_attrs(celltag, unsafe, cell_values, colwidths, colaligns):
alignment = {
"left": "",
"right": ' style="text-align: right;"',
"center": ' style="text-align: center;"',
"decimal": ' style="text-align: right;"',
}
if unsafe:
values_with_attrs = [
"<{0}{1}>{2}</{0}>".format(celltag, alignment.get(a, ""), c)
for c, a in zip(cell_values, colaligns)
]
else:
values_with_attrs = [
"<{0}{1}>{2}</{0}>".format(celltag, alignment.get(a, ""), htmlescape(c))
for c, a in zip(cell_values, colaligns)
]
rowhtml = "<tr>{}</tr>".format("".join(values_with_attrs).rstrip())
if celltag == "th": # it's a header row, create a new table header
rowhtml = "<table>\n<thead>\n{}\n</thead>\n<tbody>".format(rowhtml)
return rowhtml
def _moin_row_with_attrs(celltag, cell_values, colwidths, colaligns, header=""):
alignment = {
"left": "",
"right": '<style="text-align: right;">',
"center": '<style="text-align: center;">',
"decimal": '<style="text-align: right;">',
}
values_with_attrs = [
"{0}{1} {2} ".format(celltag, alignment.get(a, ""), header + c + header)
for c, a in zip(cell_values, colaligns)
]
return "".join(values_with_attrs) + "||"
def _latex_line_begin_tabular(colwidths, colaligns, booktabs=False, longtable=False):
alignment = {"left": "l", "right": "r", "center": "c", "decimal": "r"}
tabular_columns_fmt = "".join([alignment.get(a, "l") for a in colaligns])
return "\n".join(
[
("\\begin{tabular}{" if not longtable else "\\begin{longtable}{")
+ tabular_columns_fmt
+ "}",
"\\toprule" if booktabs else "\\hline",
]
)
LATEX_ESCAPE_RULES = {
r"&": r"\&",
r"%": r"\%",
r"$": r"\$",
r"#": r"\#",
r"_": r"\_",
r"^": r"\^{}",
r"{": r"\{",
r"}": r"\}",
r"~": r"\textasciitilde{}",
"\\": r"\textbackslash{}",
r"<": r"\ensuremath{<}",
r">": r"\ensuremath{>}",
}
def _latex_row(cell_values, colwidths, colaligns, escrules=LATEX_ESCAPE_RULES):
def escape_char(c):
return escrules.get(c, c)
escaped_values = ["".join(map(escape_char, cell)) for cell in cell_values]
rowfmt = DataRow("", "&", "\\\\")
return _build_simple_row(escaped_values, rowfmt)
def _rst_escape_first_column(rows, headers):
def escape_empty(val):
if isinstance(val, (_text_type, _binary_type)) and not val.strip():
return ".."
else:
return val
new_headers = list(headers)
new_rows = []
if headers:
new_headers[0] = escape_empty(headers[0])
for row in rows:
new_row = list(row)
if new_row:
new_row[0] = escape_empty(row[0])
new_rows.append(new_row)
return new_rows, new_headers
_table_formats = {
"simple": TableFormat(
lineabove=Line("", "-", " ", ""),
linebelowheader=Line("", "-", " ", ""),
linebetweenrows=None,
linebelow=Line("", "-", " ", ""),
headerrow=DataRow("", " ", ""),
datarow=DataRow("", " ", ""),
padding=0,
with_header_hide=["lineabove", "linebelow"],
),
"plain": TableFormat(
lineabove=None,
linebelowheader=None,
linebetweenrows=None,
linebelow=None,
headerrow=DataRow("", " ", ""),
datarow=DataRow("", " ", ""),
padding=0,
with_header_hide=None,
),
"grid": TableFormat(
lineabove=Line("+", "-", "+", "+"),
linebelowheader=Line("+", "=", "+", "+"),
linebetweenrows=Line("+", "-", "+", "+"),
linebelow=Line("+", "-", "+", "+"),
headerrow=DataRow("|", "|", "|"),
datarow=DataRow("|", "|", "|"),
padding=1,
with_header_hide=None,
),
"fancy_grid": TableFormat(
lineabove=Line("╒", "═", "╤", "╕"),
linebelowheader=Line("╞", "═", "╪", "╡"),
linebetweenrows=Line("├", "─", "┼", "┤"),
linebelow=Line("╘", "═", "╧", "╛"),
headerrow=DataRow("│", "│", "│"),
datarow=DataRow("│", "│", "│"),
padding=1,
with_header_hide=None,
),
"fancy_outline": TableFormat(
lineabove=Line("╒", "═", "╤", "╕"),
linebelowheader=Line("╞", "═", "╪", "╡"),
linebetweenrows=None,
linebelow=Line("╘", "═", "╧", "╛"),
headerrow=DataRow("│", "│", "│"),
datarow=DataRow("│", "│", "│"),
padding=1,
with_header_hide=None,
),
"github": TableFormat(
lineabove=Line("|", "-", "|", "|"),
linebelowheader=Line("|", "-", "|", "|"),
linebetweenrows=None,
linebelow=None,
headerrow=DataRow("|", "|", "|"),
datarow=DataRow("|", "|", "|"),
padding=1,
with_header_hide=["lineabove"],
),
"pipe": TableFormat(
lineabove=_pipe_line_with_colons,
linebelowheader=_pipe_line_with_colons,
linebetweenrows=None,
linebelow=None,
headerrow=DataRow("|", "|", "|"),
datarow=DataRow("|", "|", "|"),
padding=1,
with_header_hide=["lineabove"],
),
"orgtbl": TableFormat(
lineabove=None,
linebelowheader=Line("|", "-", "+", "|"),
linebetweenrows=None,
linebelow=None,
headerrow=DataRow("|", "|", "|"),
datarow=DataRow("|", "|", "|"),
padding=1,
with_header_hide=None,
),
"jira": TableFormat(
lineabove=None,
linebelowheader=None,
linebetweenrows=None,
linebelow=None,
headerrow=DataRow("||", "||", "||"),
datarow=DataRow("|", "|", "|"),
padding=1,
with_header_hide=None,
),
"presto": TableFormat(
lineabove=None,
linebelowheader=Line("", "-", "+", ""),
linebetweenrows=None,
linebelow=None,
headerrow=DataRow("", "|", ""),
datarow=DataRow("", "|", ""),
padding=1,
with_header_hide=None,
),
"pretty": TableFormat(
lineabove=Line("+", "-", "+", "+"),
linebelowheader=Line("+", "-", "+", "+"),
linebetweenrows=None,
linebelow=Line("+", "-", "+", "+"),
headerrow=DataRow("|", "|", "|"),
datarow=DataRow("|", "|", "|"),
padding=1,
with_header_hide=None,
),
"psql": TableFormat(
lineabove=Line("+", "-", "+", "+"),
linebelowheader=Line("|", "-", "+", "|"),
linebetweenrows=None,
linebelow=Line("+", "-", "+", "+"),
headerrow=DataRow("|", "|", "|"),
datarow=DataRow("|", "|", "|"),
padding=1,
with_header_hide=None,
),
"rst": TableFormat(
lineabove=Line("", "=", " ", ""),
linebelowheader=Line("", "=", " ", ""),
linebetweenrows=None,
linebelow=Line("", "=", " ", ""),
headerrow=DataRow("", " ", ""),
datarow=DataRow("", " ", ""),
padding=0,
with_header_hide=None,
),
"mediawiki": TableFormat(
lineabove=Line(
'{| class="wikitable" style="text-align: left;"',
"",
"",
"\n|+ <!-- caption -->\n|-",
),
linebelowheader=Line("|-", "", "", ""),
linebetweenrows=Line("|-", "", "", ""),
linebelow=Line("|}", "", "", ""),
headerrow=partial(_mediawiki_row_with_attrs, "!"),
datarow=partial(_mediawiki_row_with_attrs, "|"),
padding=0,
with_header_hide=None,
),
"moinmoin": TableFormat(
lineabove=None,
linebelowheader=None,
linebetweenrows=None,
linebelow=None,
headerrow=partial(_moin_row_with_attrs, "||", header="'''"),
datarow=partial(_moin_row_with_attrs, "||"),
padding=1,
with_header_hide=None,
),
"youtrack": TableFormat(
lineabove=None,
linebelowheader=None,
linebetweenrows=None,
linebelow=None,
headerrow=DataRow("|| ", " || ", " || "),
datarow=DataRow("| ", " | ", " |"),
padding=1,
with_header_hide=None,
),
"html": TableFormat(
lineabove=_html_begin_table_without_header,
linebelowheader="",
linebetweenrows=None,
linebelow=Line("</tbody>\n</table>", "", "", ""),
headerrow=partial(_html_row_with_attrs, "th", False),
datarow=partial(_html_row_with_attrs, "td", False),
padding=0,
with_header_hide=["lineabove"],
),
"unsafehtml": TableFormat(
lineabove=_html_begin_table_without_header,
linebelowheader="",
linebetweenrows=None,
linebelow=Line("</tbody>\n</table>", "", "", ""),
headerrow=partial(_html_row_with_attrs, "th", True),
datarow=partial(_html_row_with_attrs, "td", True),
padding=0,
with_header_hide=["lineabove"],
),
"latex": TableFormat(
lineabove=_latex_line_begin_tabular,
linebelowheader=Line("\\hline", "", "", ""),
linebetweenrows=None,
linebelow=Line("\\hline\n\\end{tabular}", "", "", ""),
headerrow=_latex_row,
datarow=_latex_row,
padding=1,
with_header_hide=None,
),
"latex_raw": TableFormat(
lineabove=_latex_line_begin_tabular,
linebelowheader=Line("\\hline", "", "", ""),
linebetweenrows=None,
linebelow=Line("\\hline\n\\end{tabular}", "", "", ""),
headerrow=partial(_latex_row, escrules={}),
datarow=partial(_latex_row, escrules={}),
padding=1,
with_header_hide=None,
),
"latex_booktabs": TableFormat(
lineabove=partial(_latex_line_begin_tabular, booktabs=True),
linebelowheader=Line("\\midrule", "", "", ""),
linebetweenrows=None,
linebelow=Line("\\bottomrule\n\\end{tabular}", "", "", ""),
headerrow=_latex_row,
datarow=_latex_row,
padding=1,
with_header_hide=None,
),
"latex_longtable": TableFormat(
lineabove=partial(_latex_line_begin_tabular, longtable=True),
linebelowheader=Line("\\hline\n\\endhead", "", "", ""),
linebetweenrows=None,
linebelow=Line("\\hline\n\\end{longtable}", "", "", ""),
headerrow=_latex_row,
datarow=_latex_row,
padding=1,
with_header_hide=None,
),
"tsv": TableFormat(
lineabove=None,
linebelowheader=None,
linebetweenrows=None,
linebelow=None,
headerrow=DataRow("", "\t", ""),
datarow=DataRow("", "\t", ""),
padding=0,
with_header_hide=None,
),
"textile": TableFormat(
lineabove=None,
linebelowheader=None,
linebetweenrows=None,
linebelow=None,
headerrow=DataRow("|_. ", "|_.", "|"),
datarow=_textile_row_with_attrs,
padding=1,
with_header_hide=None,
),
}
tabulate_formats = list(sorted(_table_formats.keys()))
# The table formats for which multiline cells will be folded into subsequent
# table rows. The key is the original format specified at the API. The value is
# the format that will be used to represent the original format.
multiline_formats = {
"plain": "plain",
"simple": "simple",
"grid": "grid",
"fancy_grid": "fancy_grid",
"pipe": "pipe",
"orgtbl": "orgtbl",
"jira": "jira",
"presto": "presto",
"pretty": "pretty",
"psql": "psql",
"rst": "rst",
}
# TODO: Add multiline support for the remaining table formats:
# - mediawiki: Replace \n with <br>
# - moinmoin: TBD
# - youtrack: TBD
# - html: Replace \n with <br>
# - latex*: Use "makecell" package: In header, replace X\nY with
# \thead{X\\Y} and in data row, replace X\nY with \makecell{X\\Y}
# - tsv: TBD
# - textile: Replace \n with <br/> (must be well-formed XML)
_multiline_codes = re.compile(r"\r|\n|\r\n")
_multiline_codes_bytes = re.compile(b"\r|\n|\r\n")
_invisible_codes = re.compile(
r"\x1b\[\d+[;\d]*m|\x1b\[\d*\;\d*\;\d*m|\x1b\]8;;(.*?)\x1b\\"
) # ANSI color codes
_invisible_codes_bytes = re.compile(
b"\x1b\\[\\d+\\[;\\d]*m|\x1b\\[\\d*;\\d*;\\d*m|\\x1b\\]8;;(.*?)\\x1b\\\\"
) # ANSI color codes
_invisible_codes_link = re.compile(
r"\x1B]8;[a-zA-Z0-9:]*;[^\x1B]+\x1B\\([^\x1b]+)\x1B]8;;\x1B\\"
) # Terminal hyperlinks
def simple_separated_format(separator):
"""Construct a simple TableFormat with columns separated by a separator.
>>> tsv = simple_separated_format("\\t") ; \
tabulate([["foo", 1], ["spam", 23]], tablefmt=tsv) == 'foo \\t 1\\nspam\\t23'
True
"""
return TableFormat(
None,
None,
None,
None,
headerrow=DataRow("", separator, ""),
datarow=DataRow("", separator, ""),
padding=0,
with_header_hide=None,
)
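# Example usage (a sketch): a pipe-separated format built with the helper above.
#   psv = simple_separated_format("|")
#   tabulate([["foo", 1], ["spam", 23]], tablefmt=psv)  ->  'foo | 1\nspam|23'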
def _isconvertible(conv, string):
try:
conv(string)
return True
except (ValueError, TypeError):
return False
def _isnumber(string):
"""
>>> _isnumber("123.45")
True
>>> _isnumber("123")
True
>>> _isnumber("spam")
False
>>> _isnumber("123e45678")
False
>>> _isnumber("inf")
True
"""
if not _isconvertible(float, string):
return False
elif isinstance(string, (_text_type, _binary_type)) and (
math.isinf(float(string)) or math.isnan(float(string))
):
return string.lower() in ["inf", "-inf", "nan"]
return True
def _isint(string, inttype=int):
"""
>>> _isint("123")
True
>>> _isint("123.45")
False
"""
return (
type(string) is inttype
or (isinstance(string, _binary_type) or isinstance(string, _text_type))
and _isconvertible(inttype, string)
)
def _isbool(string):
"""
>>> _isbool(True)
True
>>> _isbool("False")
True
>>> _isbool(1)
False
"""
return type(string) is _bool_type or (
isinstance(string, (_binary_type, _text_type)) and string in ("True", "False")
)
def _type(string, has_invisible=True, numparse=True):
"""The least generic type (type(None), int, float, str, unicode).
>>> _type(None) is type(None)
True
>>> _type("foo") is type("")
True
>>> _type("1") is type(1)
True
>>> _type('\x1b[31m42\x1b[0m') is type(42)
True
>>> _type('\x1b[31m42\x1b[0m') is type(42)
True
"""
if has_invisible and (
isinstance(string, _text_type) or isinstance(string, _binary_type)
):
string = _strip_invisible(string)
if string is None:
return _none_type
elif hasattr(string, "isoformat"): # datetime.datetime, date, and time
return _text_type
elif _isbool(string):
return _bool_type
elif _isint(string) and numparse:
return int
elif _isint(string, _long_type) and numparse:
return int
elif _isnumber(string) and numparse:
return float
elif isinstance(string, _binary_type):
return _binary_type
else:
return _text_type
def _afterpoint(string):
"""Symbols after a decimal point, -1 if the string lacks the decimal point.
>>> _afterpoint("123.45")
2
>>> _afterpoint("1001")
-1
>>> _afterpoint("eggs")
-1
>>> _afterpoint("123e45")
2
"""
if _isnumber(string):
if _isint(string):
return -1
else:
pos = string.rfind(".")
pos = string.lower().rfind("e") if pos < 0 else pos
if pos >= 0:
return len(string) - pos - 1
else:
return -1 # no point
else:
return -1 # not a number
def _padleft(width, s):
"""Flush right.
>>> _padleft(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430'
True
"""
fmt = "{0:>%ds}" % width
return fmt.format(s)
def _padright(width, s):
"""Flush left.
>>> _padright(6, '\u044f\u0439\u0446\u0430') == '\u044f\u0439\u0446\u0430 '
True
"""
fmt = "{0:<%ds}" % width
return fmt.format(s)
def _padboth(width, s):
"""Center string.
>>> _padboth(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430 '
True
"""
fmt = "{0:^%ds}" % width
return fmt.format(s)
def _padnone(ignore_width, s):
return s
def _strip_invisible(s):
r"""Remove invisible ANSI color codes.
>>> str(_strip_invisible('\x1B]8;;https://example.com\x1B\\This is a link\x1B]8;;\x1B\\'))
'This is a link'
"""
if isinstance(s, _text_type):
links_removed = re.sub(_invisible_codes_link, "\\1", s)
return re.sub(_invisible_codes, "", links_removed)
else: # a bytestring
return re.sub(_invisible_codes_bytes, "", s)
def _visible_width(s):
"""Visible width of a printed string. ANSI color codes are removed.
>>> _visible_width('\x1b[31mhello\x1b[0m'), _visible_width("world")
(5, 5)
"""
# optional wide-character support
if wcwidth is not None and WIDE_CHARS_MODE:
len_fn = wcwidth.wcswidth
else:
len_fn = len
if isinstance(s, _text_type) or isinstance(s, _binary_type):
return len_fn(_strip_invisible(s))
else:
return len_fn(_text_type(s))
def _is_multiline(s):
if isinstance(s, _text_type):
return bool(re.search(_multiline_codes, s))
else: # a bytestring
return bool(re.search(_multiline_codes_bytes, s))
def _multiline_width(multiline_s, line_width_fn=len):
"""Visible width of a potentially multiline content."""
return max(map(line_width_fn, re.split("[\r\n]", multiline_s)))
def _choose_width_fn(has_invisible, enable_widechars, is_multiline):
"""Return a function to calculate visible cell width."""
if has_invisible:
line_width_fn = _visible_width
elif enable_widechars: # optional wide-character support if available
line_width_fn = wcwidth.wcswidth
else:
line_width_fn = len
if is_multiline:
width_fn = lambda s: _multiline_width(s, line_width_fn) # noqa
else:
width_fn = line_width_fn
return width_fn
def _align_column_choose_padfn(strings, alignment, has_invisible):
if alignment == "right":
if not PRESERVE_WHITESPACE:
strings = [s.strip() for s in strings]
padfn = _padleft
elif alignment == "center":
if not PRESERVE_WHITESPACE:
strings = [s.strip() for s in strings]
padfn = _padboth
elif alignment == "decimal":
if has_invisible:
decimals = [_afterpoint(_strip_invisible(s)) for s in strings]
else:
decimals = [_afterpoint(s) for s in strings]
maxdecimals = max(decimals)
strings = [s + (maxdecimals - decs) * " " for s, decs in zip(strings, decimals)]
padfn = _padleft
elif not alignment:
padfn = _padnone
else:
if not PRESERVE_WHITESPACE:
strings = [s.strip() for s in strings]
padfn = _padright
return strings, padfn
def _align_column_choose_width_fn(has_invisible, enable_widechars, is_multiline):
if has_invisible:
line_width_fn = _visible_width
elif enable_widechars: # optional wide-character support if available
line_width_fn = wcwidth.wcswidth
else:
line_width_fn = len
if is_multiline:
width_fn = lambda s: _align_column_multiline_width(s, line_width_fn) # noqa
else:
width_fn = line_width_fn
return width_fn
def _align_column_multiline_width(multiline_s, line_width_fn=len):
"""Visible width of a potentially multiline content."""
return list(map(line_width_fn, re.split("[\r\n]", multiline_s)))
def _flat_list(nested_list):
ret = []
for item in nested_list:
if isinstance(item, list):
for subitem in item:
ret.append(subitem)
else:
ret.append(item)
return ret
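# Tiny example of the helper above (hypothetical values):
#   _flat_list([[1, 2], 3, [4]])  ->  [1, 2, 3, 4]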
def _align_column(
strings,
alignment,
minwidth=0,
has_invisible=True,
enable_widechars=False,
is_multiline=False,
):
"""[string] -> [padded_string]"""
strings, padfn = _align_column_choose_padfn(strings, alignment, has_invisible)
width_fn = _align_column_choose_width_fn(
has_invisible, enable_widechars, is_multiline
)
s_widths = list(map(width_fn, strings))
maxwidth = max(max(_flat_list(s_widths)), minwidth)
# TODO: refactor column alignment in single-line and multiline modes
if is_multiline:
if not enable_widechars and not has_invisible:
padded_strings = [
"\n".join([padfn(maxwidth, s) for s in ms.splitlines()])
for ms in strings
]
else:
# enable wide-character width corrections
s_lens = [[len(s) for s in re.split("[\r\n]", ms)] for ms in strings]
visible_widths = [
[maxwidth - (w - l) for w, l in zip(mw, ml)]
for mw, ml in zip(s_widths, s_lens)
]
# wcswidth and _visible_width don't count invisible characters;
# padfn doesn't need to apply another correction
padded_strings = [
"\n".join([padfn(w, s) for s, w in zip((ms.splitlines() or ms), mw)])
for ms, mw in zip(strings, visible_widths)
]
else: # single-line cell values
if not enable_widechars and not has_invisible:
padded_strings = [padfn(maxwidth, s) for s in strings]
else:
# enable wide-character width corrections
s_lens = list(map(len, strings))
visible_widths = [maxwidth - (w - l) for w, l in zip(s_widths, s_lens)]
# wcswidth and _visible_width don't count invisible characters;
# padfn doesn't need to apply another correction
padded_strings = [padfn(w, s) for s, w in zip(strings, visible_widths)]
return padded_strings
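# Sketch of the "decimal" alignment described above (hypothetical values):
#   _align_column(["12.345", "-1234.5", "1.23"], "decimal")
#   ->  ['   12.345', '-1234.5  ', '    1.23 ']
# Decimal points are lined up first, then every cell is flushed right.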
def _more_generic(type1, type2):
types = {
_none_type: 0,
_bool_type: 1,
int: 2,
float: 3,
_binary_type: 4,
_text_type: 5,
}
invtypes = {
5: _text_type,
4: _binary_type,
3: float,
2: int,
1: _bool_type,
0: _none_type,
}
moregeneric = max(types.get(type1, 5), types.get(type2, 5))
return invtypes[moregeneric]
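# The coercion lattice above resolves mixed columns, e.g. (a sketch):
#   _more_generic(int, float) is float
#   _more_generic(_bool_type, _text_type) is _text_type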
def _column_type(strings, has_invisible=True, numparse=True):
"""The least generic type all column values are convertible to.
>>> _column_type([True, False]) is _bool_type
True
>>> _column_type(["1", "2"]) is _int_type
True
>>> _column_type(["1", "2.3"]) is _float_type
True
>>> _column_type(["1", "2.3", "four"]) is _text_type
True
>>> _column_type(["four", '\u043f\u044f\u0442\u044c']) is _text_type
True
>>> _column_type([None, "brux"]) is _text_type
True
>>> _column_type([1, 2, None]) is _int_type
True
>>> import datetime as dt
>>> _column_type([dt.datetime(1991,2,19), dt.time(17,35)]) is _text_type
True
"""
types = [_type(s, has_invisible, numparse) for s in strings]
return reduce(_more_generic, types, _bool_type)
def _format(val, valtype, floatfmt, missingval="", has_invisible=True):
"""Format a value according to its type.
Unicode is supported:
>>> hrow = ['\u0431\u0443\u043a\u0432\u0430', '\u0446\u0438\u0444\u0440\u0430'] ; \
tbl = [['\u0430\u0437', 2], ['\u0431\u0443\u043a\u0438', 4]] ; \
good_result = '\\u0431\\u0443\\u043a\\u0432\\u0430 \\u0446\\u0438\\u0444\\u0440\\u0430\\n------- -------\\n\\u0430\\u0437 2\\n\\u0431\\u0443\\u043a\\u0438 4' ; \
tabulate(tbl, headers=hrow) == good_result
True
""" # noqa
if val is None:
return missingval
if valtype in [int, _text_type]:
return "{0}".format(val)
elif valtype is _binary_type:
try:
return _text_type(val, "ascii")
except TypeError:
return _text_type(val)
elif valtype is float:
is_a_colored_number = has_invisible and isinstance(
val, (_text_type, _binary_type)
)
if is_a_colored_number:
raw_val = _strip_invisible(val)
formatted_val = format(float(raw_val), floatfmt)
return val.replace(raw_val, formatted_val)
else:
return format(float(val), floatfmt)
else:
return "{0}".format(val)
def _align_header(
header, alignment, width, visible_width, is_multiline=False, width_fn=None
):
"Pad string header to width chars given known visible_width of the header."
if is_multiline:
header_lines = re.split(_multiline_codes, header)
padded_lines = [
_align_header(h, alignment, width, width_fn(h)) for h in header_lines
]
return "\n".join(padded_lines)
# else: not multiline
ninvisible = len(header) - visible_width
width += ninvisible
if alignment == "left":
return _padright(width, header)
elif alignment == "center":
return _padboth(width, header)
elif not alignment:
return "{0}".format(header)
else:
return _padleft(width, header)
def _prepend_row_index(rows, index):
"""Add a left-most index column."""
if index is None or index is False:
return rows
if len(index) != len(rows):
raise ValueError(
"index must be as long as the number of data rows: "
"got len(index)=%d and len(rows)=%d" % (len(index), len(rows))
)
rows = [[v] + list(row) for v, row in zip(index, rows)]
return rows
def _bool(val):
"A wrapper around standard bool() which doesn't throw on NumPy arrays"
try:
return bool(val)
except ValueError: # val is likely to be a numpy array with many elements
return False
def _normalize_tabular_data(tabular_data, headers, showindex="default"):
"""Transform a supported data type to a list of lists, and a list of headers.
Supported tabular data types:
* list-of-lists or another iterable of iterables
* list of named tuples (usually used with headers="keys")
* list of dicts (usually used with headers="keys")
* list of OrderedDicts (usually used with headers="keys")
* 2D NumPy arrays
* NumPy record arrays (usually used with headers="keys")
* dict of iterables (usually used with headers="keys")
* pandas.DataFrame (usually used with headers="keys")
The first row can be used as headers if headers="firstrow",
column indices can be used as headers if headers="keys".
If showindex="default", show row indices of the pandas.DataFrame.
If showindex="always", show row indices for all types of data.
If showindex="never", don't show row indices for all types of data.
If showindex is an iterable, show its values as row indices.
"""
try:
bool(headers)
is_headers2bool_broken = False # noqa
except ValueError: # numpy.ndarray, pandas.core.index.Index, ...
is_headers2bool_broken = True # noqa
headers = list(headers)
index = None
if hasattr(tabular_data, "keys") and hasattr(tabular_data, "values"):
# dict-like and pandas.DataFrame?
if hasattr(tabular_data.values, "__call__"):
# likely a conventional dict
keys = tabular_data.keys()
rows = list(
izip_longest(*tabular_data.values())
) # columns have to be transposed
elif hasattr(tabular_data, "index"):
# values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0)
keys = list(tabular_data)
if (
showindex in ["default", "always", True]
and tabular_data.index.name is not None
):
if isinstance(tabular_data.index.name, list):
keys[:0] = tabular_data.index.name
else:
keys[:0] = [tabular_data.index.name]
vals = tabular_data.values # values matrix doesn't need to be transposed
# for DataFrames add an index per default
index = list(tabular_data.index)
rows = [list(row) for row in vals]
else:
raise ValueError("tabular data doesn't appear to be a dict or a DataFrame")
if headers == "keys":
headers = list(map(_text_type, keys)) # headers should be strings
else: # it's a usual iterable of iterables, or a NumPy array
rows = list(tabular_data)
if headers == "keys" and not rows:
# an empty table (issue #81)
headers = []
elif (
headers == "keys"
and hasattr(tabular_data, "dtype")
and getattr(tabular_data.dtype, "names")
):
# numpy record array
headers = tabular_data.dtype.names
elif (
headers == "keys"
and len(rows) > 0
and isinstance(rows[0], tuple)
and hasattr(rows[0], "_fields")
):
# namedtuple
headers = list(map(_text_type, rows[0]._fields))
elif len(rows) > 0 and hasattr(rows[0], "keys") and hasattr(rows[0], "values"):
# dict-like object
uniq_keys = set() # hashed lookup for fast membership tests
keys = [] # unique keys, preserved in input order
if headers == "firstrow":
firstdict = rows[0] if len(rows) > 0 else {}
keys.extend(firstdict.keys())
uniq_keys.update(keys)
rows = rows[1:]
for row in rows:
for k in row.keys():
# Save unique items in input order
if k not in uniq_keys:
keys.append(k)
uniq_keys.add(k)
if headers == "keys":
headers = keys
elif isinstance(headers, dict):
# a dict of headers for a list of dicts
headers = [headers.get(k, k) for k in keys]
headers = list(map(_text_type, headers))
elif headers == "firstrow":
if len(rows) > 0:
headers = [firstdict.get(k, k) for k in keys]
headers = list(map(_text_type, headers))
else:
headers = []
elif headers:
raise ValueError(
"headers for a list of dicts is not a dict or a keyword"
)
rows = [[row.get(k) for k in keys] for row in rows]
elif (
headers == "keys"
and hasattr(tabular_data, "description")
and hasattr(tabular_data, "fetchone")
and hasattr(tabular_data, "rowcount")
):
# Python Database API cursor object (PEP 0249)
# print tabulate(cursor, headers='keys')
headers = [column[0] for column in tabular_data.description]
elif headers == "keys" and len(rows) > 0:
# keys are column indices
headers = list(map(_text_type, range(len(rows[0]))))
# take headers from the first row if necessary
if headers == "firstrow" and len(rows) > 0:
if index is not None:
headers = [index[0]] + list(rows[0])
index = index[1:]
else:
headers = rows[0]
headers = list(map(_text_type, headers)) # headers should be strings
rows = rows[1:]
headers = list(map(_text_type, headers))
rows = list(map(list, rows))
# add or remove an index column
showindex_is_a_str = type(showindex) in [_text_type, _binary_type]
if showindex == "default" and index is not None:
rows = _prepend_row_index(rows, index)
elif isinstance(showindex, Iterable) and not showindex_is_a_str:
rows = _prepend_row_index(rows, list(showindex))
elif showindex == "always" or (_bool(showindex) and not showindex_is_a_str):
if index is None:
index = list(range(len(rows)))
rows = _prepend_row_index(rows, index)
elif showindex == "never" or (not _bool(showindex) and not showindex_is_a_str):
pass
# pad with empty headers for initial columns if necessary
if headers and len(rows) > 0:
nhs = len(headers)
ncols = len(rows[0])
if nhs < ncols:
headers = [""] * (ncols - nhs) + headers
return rows, headers
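# Sketch of the normalization above (assumes an insertion-ordered dict,
# i.e. Python 3.7+):
#   rows, hdrs = _normalize_tabular_data({"a": [1, 2], "b": [3, 4]},
#                                        headers="keys")
#   rows == [[1, 3], [2, 4]]  # columns are transposed into rows
#   hdrs == ['a', 'b']        # dict keys become headers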
def tabulate(
tabular_data,
headers=(),
tablefmt="simple",
floatfmt=_DEFAULT_FLOATFMT,
numalign=_DEFAULT_ALIGN,
stralign=_DEFAULT_ALIGN,
missingval=_DEFAULT_MISSINGVAL,
showindex="default",
disable_numparse=False,
colalign=None,
):
"""Format a fixed width table for pretty printing.
>>> print(tabulate([[1, 2.34], [-56, "8.999"], ["2", "10001"]]))
--- ---------
1 2.34
-56 8.999
2 10001
--- ---------
The first required argument (`tabular_data`) can be a
list-of-lists (or another iterable of iterables), a list of named
tuples, a dictionary of iterables, an iterable of dictionaries,
a two-dimensional NumPy array, NumPy record array, or a Pandas'
dataframe.
Table headers
-------------
To print nice column headers, supply the second argument (`headers`):
- `headers` can be an explicit list of column headers
- if `headers="firstrow"`, then the first row of data is used
- if `headers="keys"`, then dictionary keys or column indices are used
Otherwise a headerless table is produced.
If the number of headers is less than the number of columns, they
are supposed to be names of the last columns. This is consistent
with the plain-text format of R and Pandas' dataframes.
>>> print(tabulate([["sex","age"],["Alice","F",24],["Bob","M",19]],
... headers="firstrow"))
sex age
----- ----- -----
Alice F 24
Bob M 19
By default, pandas.DataFrame data have an additional column called
row index. To add a similar column to all other types of data,
use `showindex="always"` or `showindex=True`. To suppress row indices
for all types of data, pass `showindex="never"` or `showindex=False`.
To add a custom row index column, pass `showindex=some_iterable`.
>>> print(tabulate([["F",24],["M",19]], showindex="always"))
- - --
0 F 24
1 M 19
- - --
Column alignment
----------------
`tabulate` tries to detect column types automatically, and aligns
the values properly. By default it aligns decimal points of the
numbers (or flushes integer numbers to the right), and flushes
everything else to the left. Possible column alignments
(`numalign`, `stralign`) are: "right", "center", "left", "decimal"
(only for `numalign`), and None (to disable alignment).
Table formats
-------------
`floatfmt` is a format specification used for columns which
contain numeric data with a decimal point. This can also be
a list or tuple of format strings, one per column.
`None` values are replaced with a `missingval` string (like
`floatfmt`, this can also be a list of values for different
columns):
>>> print(tabulate([["spam", 1, None],
... ["eggs", 42, 3.14],
... ["other", None, 2.7]], missingval="?"))
----- -- ----
spam 1 ?
eggs 42 3.14
other ? 2.7
----- -- ----
Various plain-text table formats (`tablefmt`) are supported:
'plain', 'simple', 'grid', 'pipe', 'orgtbl', 'rst', 'mediawiki',
'latex', 'latex_raw', 'latex_booktabs', 'latex_longtable' and tsv.
The variable `tabulate_formats` contains the list of currently supported formats.
"plain" format doesn't use any pseudographics to draw tables,
it separates columns with a double space:
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
... ["strings", "numbers"], "plain"))
strings numbers
spam 41.9999
eggs 451
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="plain"))
spam 41.9999
eggs 451
"simple" format is like Pandoc simple_tables:
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
... ["strings", "numbers"], "simple"))
strings numbers
--------- ---------
spam 41.9999
eggs 451
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="simple"))
---- --------
spam 41.9999
eggs 451
---- --------
"grid" is similar to tables produced by Emacs table.el package or
Pandoc grid_tables:
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
... ["strings", "numbers"], "grid"))
+-----------+-----------+
| strings | numbers |
+===========+===========+
| spam | 41.9999 |
+-----------+-----------+
| eggs | 451 |
+-----------+-----------+
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="grid"))
+------+----------+
| spam | 41.9999 |
+------+----------+
| eggs | 451 |
+------+----------+
"fancy_grid" draws a grid using box-drawing characters:
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
... ["strings", "numbers"], "fancy_grid"))
╒═══════════╤═══════════╕
│ strings │ numbers │
╞═══════════╪═══════════╡
│ spam │ 41.9999 │
├───────────┼───────────┤
│ eggs │ 451 │
╘═══════════╧═══════════╛
"pipe" is like tables in PHP Markdown Extra extension or Pandoc
pipe_tables:
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
... ["strings", "numbers"], "pipe"))
| strings | numbers |
|:----------|----------:|
| spam | 41.9999 |
| eggs | 451 |
"presto" is like tables produce by the Presto CLI:
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
... ["strings", "numbers"], "presto"))
strings | numbers
-----------+-----------
spam | 41.9999
eggs | 451
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="pipe"))
|:-----|---------:|
| spam | 41.9999 |
| eggs | 451 |
"orgtbl" is like tables in Emacs org-mode and orgtbl-mode. They
are slightly different from "pipe" format by not using colons to
define column alignment, and using a "+" sign to indicate line
intersections:
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
... ["strings", "numbers"], "orgtbl"))
| strings | numbers |
|-----------+-----------|
| spam | 41.9999 |
| eggs | 451 |
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="orgtbl"))
| spam | 41.9999 |
| eggs | 451 |
"rst" is like a simple table format from reStructuredText; please
note that reStructuredText accepts also "grid" tables:
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]],
... ["strings", "numbers"], "rst"))
========= =========
strings numbers
========= =========
spam 41.9999
eggs 451
========= =========
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="rst"))
==== ========
spam 41.9999
eggs 451
==== ========
"mediawiki" produces a table markup used in Wikipedia and on other
MediaWiki-based sites:
>>> print(tabulate([["strings", "numbers"], ["spam", 41.9999], ["eggs", "451.0"]],
... headers="firstrow", tablefmt="mediawiki"))
{| class="wikitable" style="text-align: left;"
|+ <!-- caption -->
|-
! strings !! align="right"| numbers
|-
| spam || align="right"| 41.9999
|-
| eggs || align="right"| 451
|}
"html" produces HTML markup as an html.escape'd str
with a ._repr_html_ method so that Jupyter Lab and Notebook display the HTML
and a .str property so that the raw HTML remains accessible
the unsafehtml table format can be used if an unescaped HTML format is required:
>>> print(tabulate([["strings", "numbers"], ["spam", 41.9999], ["eggs", "451.0"]],
... headers="firstrow", tablefmt="html"))
<table>
<thead>
<tr><th>strings </th><th style="text-align: right;"> numbers</th></tr>
</thead>
<tbody>
<tr><td>spam </td><td style="text-align: right;"> 41.9999</td></tr>
<tr><td>eggs </td><td style="text-align: right;"> 451 </td></tr>
</tbody>
</table>
"latex" produces a tabular environment of LaTeX document markup:
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="latex"))
\\begin{tabular}{lr}
\\hline
spam & 41.9999 \\\\
eggs & 451 \\\\
\\hline
\\end{tabular}
"latex_raw" is similar to "latex", but doesn't escape special characters,
such as backslash and underscore, so LaTeX commands may be embedded into
cells' values:
>>> print(tabulate([["spam$_9$", 41.9999], ["\\\\emph{eggs}", "451.0"]], tablefmt="latex_raw"))
\\begin{tabular}{lr}
\\hline
spam$_9$ & 41.9999 \\\\
\\emph{eggs} & 451 \\\\
\\hline
\\end{tabular}
"latex_booktabs" produces a tabular environment of LaTeX document markup
using the booktabs.sty package:
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="latex_booktabs"))
\\begin{tabular}{lr}
\\toprule
spam & 41.9999 \\\\
eggs & 451 \\\\
\\bottomrule
\\end{tabular}
"latex_longtable" produces a tabular environment that can stretch along
multiple pages, using the longtable package for LaTeX.
>>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="latex_longtable"))
\\begin{longtable}{lr}
\\hline
spam & 41.9999 \\\\
eggs & 451 \\\\
\\hline
\\end{longtable}
Number parsing
--------------
By default, anything which can be parsed as a number is a number.
This ensures numbers represented as strings are aligned properly.
This can lead to weird results for particular strings, such as
specific git SHAs: e.g. "42992e1" will be parsed into the number
429920 and aligned as such.
To completely disable number parsing (and alignment), use
`disable_numparse=True`. For more fine-grained control, a list of column
indices can be used to disable number parsing only on those columns,
e.g. `disable_numparse=[0, 2]` would disable number parsing only on the
first and third columns.
"""
if tabular_data is None:
tabular_data = []
list_of_lists, headers = _normalize_tabular_data(
tabular_data, headers, showindex=showindex
)
# empty values in the first column of RST tables should be escaped (issue #82)
# "" should be escaped as "\\ " or ".."
if tablefmt == "rst":
list_of_lists, headers = _rst_escape_first_column(list_of_lists, headers)
# PrettyTable formatting does not use any extra padding.
# Numbers are not parsed and are treated the same as strings for alignment.
# Check if pretty is the format being used and override the defaults so it
# does not impact other formats.
min_padding = MIN_PADDING
if tablefmt == "pretty":
min_padding = 0
disable_numparse = True
numalign = "center" if numalign == _DEFAULT_ALIGN else numalign
stralign = "center" if stralign == _DEFAULT_ALIGN else stralign
else:
numalign = "decimal" if numalign == _DEFAULT_ALIGN else numalign
stralign = "left" if stralign == _DEFAULT_ALIGN else stralign
# optimization: look for ANSI control codes once,
# enable smart width functions only if a control code is found
plain_text = "\t".join(
["\t".join(map(_text_type, headers))]
+ ["\t".join(map(_text_type, row)) for row in list_of_lists]
)
has_invisible = re.search(_invisible_codes, plain_text)
if not has_invisible:
has_invisible = re.search(_invisible_codes_link, plain_text)
enable_widechars = wcwidth is not None and WIDE_CHARS_MODE
if (
not isinstance(tablefmt, TableFormat)
and tablefmt in multiline_formats
and _is_multiline(plain_text)
):
tablefmt = multiline_formats.get(tablefmt, tablefmt)
is_multiline = True
else:
is_multiline = False
width_fn = _choose_width_fn(has_invisible, enable_widechars, is_multiline)
# format rows and columns, convert numeric values to strings
cols = list(izip_longest(*list_of_lists))
numparses = _expand_numparse(disable_numparse, len(cols))
coltypes = [_column_type(col, numparse=np) for col, np in zip(cols, numparses)]
if isinstance(floatfmt, basestring): # old version
float_formats = len(cols) * [
floatfmt
] # just duplicate the string to use in each column
else: # if floatfmt is list, tuple etc we have one per column
float_formats = list(floatfmt)
if len(float_formats) < len(cols):
float_formats.extend((len(cols) - len(float_formats)) * [_DEFAULT_FLOATFMT])
if isinstance(missingval, basestring):
missing_vals = len(cols) * [missingval]
else:
missing_vals = list(missingval)
if len(missing_vals) < len(cols):
missing_vals.extend((len(cols) - len(missing_vals)) * [_DEFAULT_MISSINGVAL])
cols = [
[_format(v, ct, fl_fmt, miss_v, has_invisible) for v in c]
for c, ct, fl_fmt, miss_v in zip(cols, coltypes, float_formats, missing_vals)
]
# align columns
aligns = [numalign if ct in [int, float] else stralign for ct in coltypes]
if colalign is not None:
assert isinstance(colalign, Iterable)
for idx, align in enumerate(colalign):
aligns[idx] = align
minwidths = (
[width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols)
)
cols = [
_align_column(c, a, minw, has_invisible, enable_widechars, is_multiline)
for c, a, minw in zip(cols, aligns, minwidths)
]
if headers:
# align headers and add headers
t_cols = cols or [[""]] * len(headers)
t_aligns = aligns or [stralign] * len(headers)
minwidths = [
max(minw, max(width_fn(cl) for cl in c))
for minw, c in zip(minwidths, t_cols)
]
headers = [
_align_header(h, a, minw, width_fn(h), is_multiline, width_fn)
for h, a, minw in zip(headers, t_aligns, minwidths)
]
rows = list(zip(*cols))
else:
minwidths = [max(width_fn(cl) for cl in c) for c in cols]
rows = list(zip(*cols))
if not isinstance(tablefmt, TableFormat):
tablefmt = _table_formats.get(tablefmt, _table_formats["simple"])
return _format_table(tablefmt, headers, rows, minwidths, aligns, is_multiline)
def _expand_numparse(disable_numparse, column_count):
"""
Return a list of bools of length `column_count` which indicates whether
number parsing should be used on each column.
If `disable_numparse` is a list of indices, each of those indices is False,
and everything else is True.
If `disable_numparse` is a bool, every entry is `not disable_numparse`.
"""
if isinstance(disable_numparse, Iterable):
numparses = [True] * column_count
for index in disable_numparse:
numparses[index] = False
return numparses
else:
return [not disable_numparse] * column_count
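# Sketch of the behaviour described in the docstring (hypothetical values):
#   _expand_numparse([0, 2], 4)  ->  [False, True, False, True]
#   _expand_numparse(False, 3)   ->  [True, True, True]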
def _pad_row(cells, padding):
if cells:
pad = " " * padding
padded_cells = [pad + cell + pad for cell in cells]
return padded_cells
else:
return cells
def _build_simple_row(padded_cells, rowfmt):
"Format row according to DataRow format without padding."
begin, sep, end = rowfmt
return (begin + sep.join(padded_cells) + end).rstrip()
def _build_row(padded_cells, colwidths, colaligns, rowfmt):
"Return a string which represents a row of data cells."
if not rowfmt:
return None
if hasattr(rowfmt, "__call__"):
return rowfmt(padded_cells, colwidths, colaligns)
else:
return _build_simple_row(padded_cells, rowfmt)
def _append_basic_row(lines, padded_cells, colwidths, colaligns, rowfmt):
lines.append(_build_row(padded_cells, colwidths, colaligns, rowfmt))
return lines
def _append_multiline_row(
lines, padded_multiline_cells, padded_widths, colaligns, rowfmt, pad
):
colwidths = [w - 2 * pad for w in padded_widths]
cells_lines = [c.splitlines() for c in padded_multiline_cells]
nlines = max(map(len, cells_lines)) # number of lines in the row
# vertically pad cells where some lines are missing
cells_lines = [
(cl + [" " * w] * (nlines - len(cl))) for cl, w in zip(cells_lines, colwidths)
]
lines_cells = [[cl[i] for cl in cells_lines] for i in range(nlines)]
for ln in lines_cells:
padded_ln = _pad_row(ln, pad)
_append_basic_row(lines, padded_ln, colwidths, colaligns, rowfmt)
return lines
def _build_line(colwidths, colaligns, linefmt):
"Return a string which represents a horizontal line."
if not linefmt:
return None
if hasattr(linefmt, "__call__"):
return linefmt(colwidths, colaligns)
else:
begin, fill, sep, end = linefmt
cells = [fill * w for w in colwidths]
return _build_simple_row(cells, (begin, sep, end))
def _append_line(lines, colwidths, colaligns, linefmt):
lines.append(_build_line(colwidths, colaligns, linefmt))
return lines
class JupyterHTMLStr(str):
"""Wrap the string with a _repr_html_ method so that Jupyter
displays the HTML table"""
def _repr_html_(self):
return self
@property
def str(self):
"""add a .str property so that the raw string is still accessible"""
return self
def _format_table(fmt, headers, rows, colwidths, colaligns, is_multiline):
"""Produce a plain-text representation of the table."""
lines = []
hidden = fmt.with_header_hide if (headers and fmt.with_header_hide) else []
pad = fmt.padding
headerrow = fmt.headerrow
padded_widths = [(w + 2 * pad) for w in colwidths]
if is_multiline:
pad_row = lambda row, _: row # noqa do it later, in _append_multiline_row
append_row = partial(_append_multiline_row, pad=pad)
else:
pad_row = _pad_row
append_row = _append_basic_row
padded_headers = pad_row(headers, pad)
padded_rows = [pad_row(row, pad) for row in rows]
if fmt.lineabove and "lineabove" not in hidden:
_append_line(lines, padded_widths, colaligns, fmt.lineabove)
if padded_headers:
append_row(lines, padded_headers, padded_widths, colaligns, headerrow)
if fmt.linebelowheader and "linebelowheader" not in hidden:
_append_line(lines, padded_widths, colaligns, fmt.linebelowheader)
if padded_rows and fmt.linebetweenrows and "linebetweenrows" not in hidden:
# initial rows with a line below
for row in padded_rows[:-1]:
append_row(lines, row, padded_widths, colaligns, fmt.datarow)
_append_line(lines, padded_widths, colaligns, fmt.linebetweenrows)
# the last row without a line below
append_row(lines, padded_rows[-1], padded_widths, colaligns, fmt.datarow)
else:
for row in padded_rows:
append_row(lines, row, padded_widths, colaligns, fmt.datarow)
if fmt.linebelow and "linebelow" not in hidden:
_append_line(lines, padded_widths, colaligns, fmt.linebelow)
if headers or rows:
output = "\n".join(lines)
if fmt.lineabove == _html_begin_table_without_header:
return JupyterHTMLStr(output)
else:
return output
else: # a completely empty table
return ""
def _main():
"""\
Usage: tabulate [options] [FILE ...]
Pretty-print tabular data.
See also https://github.com/astanin/python-tabulate
FILE a filename of the file with tabular data;
if "-" or missing, read data from stdin.
Options:
-h, --help show this message
-1, --header use the first row of data as a table header
-o FILE, --output FILE print table to FILE (default: stdout)
-s REGEXP, --sep REGEXP use a custom column separator (default: whitespace)
-F FPFMT, --float FPFMT floating point number format (default: g)
-A ALIGNS, --align ALIGNS space-separated list of column alignments
-f FMT, --format FMT set output table format; supported formats:
plain, simple, grid, fancy_grid, pipe, orgtbl,
rst, mediawiki, html, latex, latex_raw,
latex_booktabs, latex_longtable, tsv
(default: simple)
"""
import getopt
import sys
import textwrap
usage = textwrap.dedent(_main.__doc__)
try:
opts, args = getopt.getopt(
sys.argv[1:],
"h1o:s:F:A:f:",
["help", "header", "output", "sep=", "float=", "align=", "format="],
)
except getopt.GetoptError as e:
print(e)
print(usage)
sys.exit(2)
headers = []
floatfmt = _DEFAULT_FLOATFMT
colalign = None
tablefmt = "simple"
sep = r"\s+"
outfile = "-"
for opt, value in opts:
if opt in ["-1", "--header"]:
headers = "firstrow"
elif opt in ["-o", "--output"]:
outfile = value
elif opt in ["-F", "--float"]:
floatfmt = value
elif opt in ["-C", "--colalign"]:
colalign = value.split()
elif opt in ["-f", "--format"]:
if value not in tabulate_formats:
print("%s is not a supported table format" % value)
print(usage)
sys.exit(3)
tablefmt = value
elif opt in ["-s", "--sep"]:
sep = value
elif opt in ["-h", "--help"]:
print(usage)
sys.exit(0)
files = [sys.stdin] if not args else args
with (sys.stdout if outfile == "-" else open(outfile, "w")) as out:
for f in files:
if f == "-":
f = sys.stdin
if _is_file(f):
_pprint_file(
f,
headers=headers,
tablefmt=tablefmt,
sep=sep,
floatfmt=floatfmt,
file=out,
colalign=colalign,
)
else:
with open(f) as fobj:
_pprint_file(
fobj,
headers=headers,
tablefmt=tablefmt,
sep=sep,
floatfmt=floatfmt,
file=out,
colalign=colalign,
)
def _pprint_file(fobject, headers, tablefmt, sep, floatfmt, file, colalign):
rows = fobject.readlines()
table = [re.split(sep, r.rstrip()) for r in rows if r.strip()]
print(
tabulate(table, headers, tablefmt, floatfmt=floatfmt, colalign=colalign),
file=file,
)
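# Example invocation (a sketch; assumes this module is executed as a script):
#   printf "name qty\nspam 42\n" | python tabulate.py -1 -f grid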
if __name__ == "__main__":
_main()
# coding: utf-8
# Copyright (c) 2008-2011 Volvox Development Team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# Author: Konstantin Lepa <konstantin.lepa@gmail.com>
"""ANSII Color formatting for output in terminal."""
from __future__ import print_function
import os
__all__ = ['colored', 'cprint']
VERSION = (1, 1, 0)
ATTRIBUTES = dict(
list(zip([
'bold',
'dark',
'',
'underline',
'blink',
'',
'reverse',
'concealed'
],
list(range(1, 9))
))
)
del ATTRIBUTES['']
HIGHLIGHTS = dict(
list(zip([
'on_grey',
'on_red',
'on_green',
'on_yellow',
'on_blue',
'on_magenta',
'on_cyan',
'on_white'
],
list(range(40, 48))
))
)
COLORS = dict(
list(zip([
'grey',
'red',
'green',
'yellow',
'blue',
'magenta',
'cyan',
'white',
],
list(range(30, 38))
))
)
RESET = '\033[0m'
def colored(text, color=None, on_color=None, attrs=None):
"""Colorize text.
Available text colors:
red, green, yellow, blue, magenta, cyan, white.
Available text highlights:
on_red, on_green, on_yellow, on_blue, on_magenta, on_cyan, on_white.
Available attributes:
bold, dark, underline, blink, reverse, concealed.
Example:
colored('Hello, World!', 'red', 'on_grey', ['blue', 'blink'])
colored('Hello, World!', 'green')
"""
if os.getenv('ANSI_COLORS_DISABLED') is None:
fmt_str = '\033[%dm%s'
if color is not None:
text = fmt_str % (COLORS[color], text)
if on_color is not None:
text = fmt_str % (HIGHLIGHTS[on_color], text)
if attrs is not None:
for attr in attrs:
text = fmt_str % (ATTRIBUTES[attr], text)
text += RESET
return text
def cprint(text, color=None, on_color=None, attrs=None, **kwargs):
"""Print colorize text.
It accepts arguments of print function.
"""
print((colored(text, color, on_color, attrs)), **kwargs)
if __name__ == '__main__':
print('Current terminal type: %s' % os.getenv('TERM'))
print('Test basic colors:')
cprint('Grey color', 'grey')
cprint('Red color', 'red')
cprint('Green color', 'green')
cprint('Yellow color', 'yellow')
cprint('Blue color', 'blue')
cprint('Magenta color', 'magenta')
cprint('Cyan color', 'cyan')
cprint('White color', 'white')
print(('-' * 78))
print('Test highlights:')
cprint('On grey color', on_color='on_grey')
cprint('On red color', on_color='on_red')
cprint('On green color', on_color='on_green')
cprint('On yellow color', on_color='on_yellow')
cprint('On blue color', on_color='on_blue')
cprint('On magenta color', on_color='on_magenta')
cprint('On cyan color', on_color='on_cyan')
cprint('On white color', color='grey', on_color='on_white')
print('-' * 78)
print('Test attributes:')
cprint('Bold grey color', 'grey', attrs=['bold'])
cprint('Dark red color', 'red', attrs=['dark'])
cprint('Underline green color', 'green', attrs=['underline'])
cprint('Blink yellow color', 'yellow', attrs=['blink'])
cprint('Reversed blue color', 'blue', attrs=['reverse'])
cprint('Concealed Magenta color', 'magenta', attrs=['concealed'])
cprint('Bold underline reverse cyan color', 'cyan',
attrs=['bold', 'underline', 'reverse'])
cprint('Dark blink concealed white color', 'white',
attrs=['dark', 'blink', 'concealed'])
print(('-' * 78))
print('Test mixing:')
cprint('Underline red on grey color', 'red', 'on_grey',
['underline'])
cprint('Reversed green on red color', 'green', 'on_red', ['reverse'])
"""threadpoolctl
This module provides utilities to introspect native libraries that rely on
thread pools (notably BLAS and OpenMP implementations) and dynamically set the
maximal number of threads they can use.
"""
# License: BSD 3-Clause
# The code to introspect dynamically loaded libraries on POSIX systems is
# adapted from code by Intel developer @anton-malakhov available at
# https://github.com/IntelPython/smp (Copyright (c) 2017, Intel Corporation)
# and also published under the BSD 3-Clause license
import os
import re
import sys
import ctypes
import textwrap
import warnings
from ctypes.util import find_library
from abc import ABC, abstractmethod
from functools import lru_cache
from contextlib import ContextDecorator
__version__ = "3.0.0"
__all__ = ["threadpool_limits", "threadpool_info", "ThreadpoolController"]
# One can get runtime errors or even segfaults due to multiple OpenMP libraries
# loaded simultaneously which can happen easily in Python when importing and
# using compiled extensions built with different compilers and therefore
# different OpenMP runtimes in the same program. In particular libiomp (used by
# Intel ICC) and libomp used by clang/llvm tend to crash. This can happen for
# instance when calling BLAS inside a prange. Setting the following environment
# variable allows multiple OpenMP libraries to be loaded. It should not degrade
# performance since we manually take care of potential over-subscription
# performance issues, in sections of the code where nested OpenMP loops can
# happen, by dynamically reconfiguring the inner OpenMP runtime to temporarily
# disable it while under the scope of the outer OpenMP parallel section.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "True")
# Structure to cast the info on dynamically loaded library. See
# https://linux.die.net/man/3/dl_iterate_phdr for more details.
_SYSTEM_UINT = ctypes.c_uint64 if sys.maxsize > 2 ** 32 else ctypes.c_uint32
_SYSTEM_UINT_HALF = ctypes.c_uint32 if sys.maxsize > 2 ** 32 else ctypes.c_uint16
class _dl_phdr_info(ctypes.Structure):
_fields_ = [
("dlpi_addr", _SYSTEM_UINT), # Base address of object
("dlpi_name", ctypes.c_char_p), # path to the library
("dlpi_phdr", ctypes.c_void_p), # pointer on dlpi_headers
("dlpi_phnum", _SYSTEM_UINT_HALF), # number of elements in dlpi_phdr
]
# The RTLD_NOLOAD flag for loading shared libraries is not defined on Windows.
try:
_RTLD_NOLOAD = os.RTLD_NOLOAD
except AttributeError:
_RTLD_NOLOAD = ctypes.DEFAULT_MODE
# List of the supported libraries. The items are indexed by the name of the
# class to instantiate to create the library controller objects. The items hold
# the possible prefixes of loaded shared objects, the name of the internal_api
# to call and the name of the user_api.
_SUPPORTED_LIBRARIES = {
"OpenMPController": {
"user_api": "openmp",
"internal_api": "openmp",
"filename_prefixes": ("libiomp", "libgomp", "libomp", "vcomp"),
},
"OpenBLASController": {
"user_api": "blas",
"internal_api": "openblas",
"filename_prefixes": ("libopenblas",),
},
"MKLController": {
"user_api": "blas",
"internal_api": "mkl",
"filename_prefixes": ("libmkl_rt", "mkl_rt"),
},
"BLISController": {
"user_api": "blas",
"internal_api": "blis",
"filename_prefixes": ("libblis",),
},
}
# Helpers for the doc and test names
_ALL_USER_APIS = list(set(lib["user_api"] for lib in _SUPPORTED_LIBRARIES.values()))
_ALL_INTERNAL_APIS = [lib["internal_api"] for lib in _SUPPORTED_LIBRARIES.values()]
_ALL_PREFIXES = [
prefix
for lib in _SUPPORTED_LIBRARIES.values()
for prefix in lib["filename_prefixes"]
]
_ALL_BLAS_LIBRARIES = [
lib["internal_api"]
for lib in _SUPPORTED_LIBRARIES.values()
if lib["user_api"] == "blas"
]
_ALL_OPENMP_LIBRARIES = list(
_SUPPORTED_LIBRARIES["OpenMPController"]["filename_prefixes"]
)
def _format_docstring(*args, **kwargs):
def decorator(o):
if o.__doc__ is not None:
o.__doc__ = o.__doc__.format(*args, **kwargs)
return o
return decorator
@lru_cache(maxsize=10000)
def _realpath(filepath):
"""Small caching wrapper around os.path.realpath to limit system calls"""
return os.path.realpath(filepath)
@_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS)
def threadpool_info():
"""Return the maximal number of threads for each detected library.
Return a list with all the supported libraries that have been found. Each
library is represented by a dict with the following information:
- "user_api" : user API. Possible values are {USER_APIS}.
- "internal_api": internal API. Possible values are {INTERNAL_APIS}.
- "prefix" : filename prefix of the specific implementation.
- "filepath": path to the loaded library.
- "version": version of the library (if available).
- "num_threads": the current thread limit.
In addition, each library may contain internal_api specific entries.
"""
return ThreadpoolController().info()
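# Example usage (a sketch): dump the detected thread pools, if any.
#   import pprint
#   pprint.pprint(threadpool_info())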
class _ThreadpoolLimiter:
"""The guts of ThreadpoolController.limit
Refer to the docstring of ThreadpoolController.limit for more details.
It will only act on the library controllers held by the provided `controller`.
Using the default constructor sets the limits right away such that it can be used as
a callable. Setting the limits can be delayed by using the `wrap` class method such
that it can be used as a decorator.
"""
def __init__(self, controller, *, limits=None, user_api=None):
self._limits, self._user_api, self._prefixes = self._check_params(
limits, user_api
)
self._controller = controller
self._original_info = self._controller.info()
self._set_threadpool_limits()
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.restore_original_limits()
@classmethod
def wrap(cls, controller, *, limits=None, user_api=None):
"""Return an instance of this class that can be used as a decorator"""
return _ThreadpoolLimiterDecorator(
controller=controller, limits=limits, user_api=user_api
)
def restore_original_limits(self):
"""Set the limits back to their original values"""
for lib_controller, original_info in zip(
self._controller.lib_controllers, self._original_info
):
lib_controller.set_num_threads(original_info["num_threads"])
# Alias of `restore_original_limits` for backward compatibility
unregister = restore_original_limits
def get_original_num_threads(self):
"""Original num_threads from before calling threadpool_limits
Return a dict `{user_api: num_threads}`.
"""
num_threads = {}
warning_apis = []
for user_api in self._user_api:
limits = [
lib_info["num_threads"]
for lib_info in self._original_info
if lib_info["user_api"] == user_api
]
limits = set(limits)
n_limits = len(limits)
if n_limits == 1:
limit = limits.pop()
elif n_limits == 0:
limit = None
else:
limit = min(limits)
warning_apis.append(user_api)
num_threads[user_api] = limit
if warning_apis:
warnings.warn(
"Multiple value possible for following user apis: "
+ ", ".join(warning_apis)
+ ". Returning the minimum."
)
return num_threads
def _check_params(self, limits, user_api):
"""Suitable values for the _limits, _user_api and _prefixes attributes"""
if limits is None or isinstance(limits, int):
if user_api is None:
user_api = _ALL_USER_APIS
elif user_api in _ALL_USER_APIS:
user_api = [user_api]
else:
raise ValueError(
f"user_api must be either in {_ALL_USER_APIS} or None. Got "
f"{user_api} instead."
)
if limits is not None:
limits = {api: limits for api in user_api}
prefixes = []
else:
if isinstance(limits, list):
# This should be a list of dicts of library info, for
# compatibility with the result from threadpool_info.
limits = {
lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits
}
elif isinstance(limits, ThreadpoolController):
# To set the limits from the library controllers of a
# ThreadpoolController object.
limits = {
lib_controller.prefix: lib_controller.num_threads
for lib_controller in limits.lib_controllers
}
if not isinstance(limits, dict):
raise TypeError(
"limits must either be an int, a list or a "
f"dict. Got {type(limits)} instead"
)
# With a dictionary, one can set both a specific limit for given
# libraries and a global limit per user_api. Fetch each separately.
prefixes = [prefix for prefix in limits if prefix in _ALL_PREFIXES]
user_api = [api for api in limits if api in _ALL_USER_APIS]
return limits, user_api, prefixes
def _set_threadpool_limits(self):
"""Change the maximal number of threads in selected thread pools.
Return a list with all the supported libraries that have been found
matching `self._prefixes` and `self._user_api`.
"""
if self._limits is None:
return
for lib_controller in self._controller.lib_controllers:
# self._limits is a dict {key: num_threads} where key is either
# a prefix or a user_api. If a library matches both, the limit
# corresponding to the prefix is chosen.
if lib_controller.prefix in self._limits:
num_threads = self._limits[lib_controller.prefix]
elif lib_controller.user_api in self._limits:
num_threads = self._limits[lib_controller.user_api]
else:
continue
if num_threads is not None:
lib_controller.set_num_threads(num_threads)
class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator):
"""Same as _ThreadpoolLimiter but to be used as a decorator"""
def __init__(self, controller, *, limits=None, user_api=None):
self._limits, self._user_api, self._prefixes = self._check_params(
limits, user_api
)
self._controller = controller
def __enter__(self):
# we need to set the limits here and not in the __init__ because we want the
# limits to be set when calling the decorated function, not when creating the
# decorator.
self._original_info = self._controller.info()
self._set_threadpool_limits()
return self
@_format_docstring(
USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
class threadpool_limits(_ThreadpoolLimiter):
"""Change the maximal number of threads that can be used in thread pools.
This object can be used either as a callable (the construction of this object
limits the number of threads), as a context manager in a `with` block to
automatically restore the original state of the controlled libraries when exiting
the block, or as a decorator through its `wrap` method.
Set the maximal number of threads that can be used in thread pools used in
the supported libraries to `limit`. This function works for libraries that
are already loaded in the interpreter and can be changed dynamically.
This effect is global and impacts the whole Python process. There is no thread level
isolation as these libraries do not offer thread-local APIs to configure the number
of threads to use in nested parallel calls.
Parameters
----------
limits : int, dict or None (default=None)
The maximal number of threads that can be used in thread pools
- If int, sets the maximum number of threads to `limits` for each
library selected by `user_api`.
- If it is a dictionary `{{key: max_threads}}`, this function sets a
custom maximum number of threads for each `key` which can be either a
`user_api` or a `prefix` for a specific library.
- If None, this function does not do anything.
user_api : {USER_APIS} or None (default=None)
APIs of libraries to limit. Used only if `limits` is an int.
- If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).
- If "openmp", it will only limit OpenMP supported libraries
({OPENMP_LIBS}). Note that it can affect the number of threads used
by the BLAS libraries if they rely on OpenMP.
- If None, this function will apply to all supported libraries.
"""
def __init__(self, limits=None, user_api=None):
super().__init__(ThreadpoolController(), limits=limits, user_api=user_api)
@classmethod
def wrap(cls, limits=None, user_api=None):
return super().wrap(ThreadpoolController(), limits=limits, user_api=user_api)
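# Usage sketch for the class defined above (illustrative; the numeric limits
# are arbitrary):
#
#     from threadpoolctl import threadpool_limits
#
#     threadpool_limits(limits=4)          # plain call: the new limits persist
#
#     with threadpool_limits(limits=1, user_api="blas"):
#         ...                              # restored when the block exits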
@_format_docstring(
PREFIXES=", ".join(f'"{prefix}"' for prefix in _ALL_PREFIXES),
USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
class ThreadpoolController:
"""Collection of LibController objects for all loaded supported libraries
Attributes
----------
lib_controllers : list of `LibController` objects
The list of library controllers of all loaded supported libraries.
"""
# Cache for libc under POSIX and a few system libraries under Windows.
# We use a class level cache instead of an instance level cache because
# it's very unlikely that a shared library will be unloaded and reloaded
# during the lifetime of a program.
_system_libraries = dict()
def __init__(self):
self.lib_controllers = []
self._load_libraries()
self._warn_if_incompatible_openmp()
@classmethod
def _from_controllers(cls, lib_controllers):
new_controller = cls.__new__(cls)
new_controller.lib_controllers = lib_controllers
return new_controller
def info(self):
"""Return lib_controllers info as a list of dicts"""
return [lib_controller.info() for lib_controller in self.lib_controllers]
def select(self, **kwargs):
"""Return a ThreadpoolController containing a subset of its current
library controllers
It will select all libraries matching at least one pair (key, value) from kwargs
where key is an entry of the library info dict (like "user_api", "internal_api",
"prefix", ...) and value is the value or a list of acceptable values for that
entry.
For instance, `ThreadpoolController().select(internal_api=["blis", "openblas"])`
will select all library controllers whose internal_api is either "blis" or
"openblas".
"""
for key, vals in kwargs.items():
kwargs[key] = [vals] if not isinstance(vals, list) else vals
lib_controllers = [
lib_controller
for lib_controller in self.lib_controllers
if any(
getattr(lib_controller, key, None) in vals
for key, vals in kwargs.items()
)
]
return ThreadpoolController._from_controllers(lib_controllers)
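# Sketch (illustrative): `select` composes with `limit` below, so only the
# matching libraries are capped while the rest keep their settings:
#
#     controller = ThreadpoolController()
#     with controller.select(user_api="blas").limit(limits=1):
#         ...  # only BLAS thread pools are limited inside this block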
@_format_docstring(
USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
def limit(self, *, limits=None, user_api=None):
"""Change the maximal number of threads that can be used in thread pools.
This function returns an object that can be used either as a callable (the
construction of this object limits the number of threads) or as a context
manager, in a `with` block to automatically restore the original state of the
controlled libraries when exiting the block.
Set the maximal number of threads that can be used in thread pools used in
the supported libraries to `limits`. This only affects libraries that are
already loaded in the interpreter and whose threadpool size can be changed
dynamically.
This effect is global and impacts the whole Python process. There is no
thread-level isolation as these libraries do not offer thread-local APIs to
configure the number of threads to use in nested parallel calls.
Parameters
----------
limits : int, dict or None (default=None)
The maximal number of threads that can be used in thread pools
- If int, sets the maximum number of threads to `limits` for each
library selected by `user_api`.
- If it is a dictionary `{{key: max_threads}}`, this function sets a
custom maximum number of threads for each `key` which can be either a
`user_api` or a `prefix` for a specific library.
- If None, this function does not do anything.
user_api : {USER_APIS} or None (default=None)
APIs of libraries to limit. Used only if `limits` is an int.
- If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).
- If "openmp", it will only limit OpenMP supported libraries
({OPENMP_LIBS}). Note that it can affect the number of threads used
by the BLAS libraries if they rely on OpenMP.
- If None, this function will apply to all supported libraries.
"""
return _ThreadpoolLimiter(self, limits=limits, user_api=user_api)
@_format_docstring(
USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
def wrap(self, *, limits=None, user_api=None):
"""Change the maximal number of threads that can be used in thread pools.
This function returns an object that can be used as a decorator.
Set the maximal number of threads that can be used in thread pools used in
the supported libraries to `limits`. This only affects libraries that are
already loaded in the interpreter and whose threadpool size can be changed
dynamically.
Parameters
----------
limits : int, dict or None (default=None)
The maximal number of threads that can be used in thread pools
- If int, sets the maximum number of threads to `limits` for each
library selected by `user_api`.
- If it is a dictionary `{{key: max_threads}}`, this function sets a
custom maximum number of threads for each `key` which can be either a
`user_api` or a `prefix` for a specific library.
- If None, this function does not do anything.
user_api : {USER_APIS} or None (default=None)
APIs of libraries to limit. Used only if `limits` is an int.
- If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).
- If "openmp", it will only limit OpenMP supported libraries
({OPENMP_LIBS}). Note that it can affect the number of threads used
by the BLAS libraries if they rely on OpenMP.
- If None, this function will apply to all supported libraries.
"""
return _ThreadpoolLimiter.wrap(self, limits=limits, user_api=user_api)
def __len__(self):
return len(self.lib_controllers)
def _load_libraries(self):
"""Loop through loaded shared libraries and store the supported ones"""
if sys.platform == "darwin":
self._find_libraries_with_dyld()
elif sys.platform == "win32":
self._find_libraries_with_enum_process_module_ex()
else:
self._find_libraries_with_dl_iterate_phdr()
def _find_libraries_with_dl_iterate_phdr(self):
"""Loop through loaded libraries and return binders on supported ones
This function is expected to work on POSIX system only.
This code is adapted from code by Intel developper @anton-malakhov
available at https://github.com/IntelPython/smp
Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause
license
"""
libc = self._get_libc()
if not hasattr(libc, "dl_iterate_phdr"): # pragma: no cover
return []
# Callback function for `dl_iterate_phdr` which is called for every
# library loaded in the current process until it returns 1.
def match_library_callback(info, size, data):
# Get the path of the current library
filepath = info.contents.dlpi_name
if filepath:
filepath = filepath.decode("utf-8")
# Store the library controller if it is supported and selected
self._make_controller_from_path(filepath)
return 0
c_func_signature = ctypes.CFUNCTYPE(
ctypes.c_int, # Return type
ctypes.POINTER(_dl_phdr_info),
ctypes.c_size_t,
ctypes.c_char_p,
)
c_match_library_callback = c_func_signature(match_library_callback)
data = ctypes.c_char_p(b"")
libc.dl_iterate_phdr(c_match_library_callback, data)
def _find_libraries_with_dyld(self):
"""Loop through loaded libraries and return binders on supported ones
This function is expected to work on OSX system only
"""
libc = self._get_libc()
if not hasattr(libc, "_dyld_image_count"): # pragma: no cover
return []
n_dyld = libc._dyld_image_count()
libc._dyld_get_image_name.restype = ctypes.c_char_p
for i in range(n_dyld):
filepath = ctypes.string_at(libc._dyld_get_image_name(i))
filepath = filepath.decode("utf-8")
# Store the library controller if it is supported and selected
self._make_controller_from_path(filepath)
def _find_libraries_with_enum_process_module_ex(self):
"""Loop through loaded libraries and return binders on supported ones
This function is expected to work on windows system only.
This code is adapted from code by Philipp Hagemeister @phihag available
at https://stackoverflow.com/questions/17474574
"""
from ctypes.wintypes import DWORD, HMODULE, MAX_PATH
PROCESS_QUERY_INFORMATION = 0x0400
PROCESS_VM_READ = 0x0010
LIST_LIBRARIES_ALL = 0x03
ps_api = self._get_windll("Psapi")
kernel_32 = self._get_windll("kernel32")
h_process = kernel_32.OpenProcess(
PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, False, os.getpid()
)
if not h_process: # pragma: no cover
raise OSError(f"Could not open PID {os.getpid()}")
try:
buf_count = 256
needed = DWORD()
# Grow the buffer until it becomes large enough to hold all the
# module headers
while True:
buf = (HMODULE * buf_count)()
buf_size = ctypes.sizeof(buf)
if not ps_api.EnumProcessModulesEx(
h_process,
ctypes.byref(buf),
buf_size,
ctypes.byref(needed),
LIST_LIBRARIES_ALL,
):
raise OSError("EnumProcessModulesEx failed")
if buf_size >= needed.value:
break
# buf_size // buf_count is sizeof(HMODULE): convert the required
# byte count reported in `needed` into a module count.
buf_count = needed.value // (buf_size // buf_count)
count = needed.value // (buf_size // buf_count)
h_modules = map(HMODULE, buf[:count])
# Loop through all the module headers and get the library path
buf = ctypes.create_unicode_buffer(MAX_PATH)
n_size = DWORD()
for h_module in h_modules:
# Get the path of the current module
if not ps_api.GetModuleFileNameExW(
h_process, h_module, ctypes.byref(buf), ctypes.byref(n_size)
):
raise OSError("GetModuleFileNameEx failed")
filepath = buf.value
# Store the library controller if it is supported and selected
self._make_controller_from_path(filepath)
finally:
kernel_32.CloseHandle(h_process)
def _make_controller_from_path(self, filepath):
"""Store a library controller if it is supported and selected"""
# Required to resolve symlinks
filepath = _realpath(filepath)
# `lower` required to take into account the case of the OpenMP dll
# name on Windows (vcomp, VCOMP, Vcomp, ...)
filename = os.path.basename(filepath).lower()
# Loop through supported libraries to find if this filename corresponds
# to a supported one.
for controller_class, candidate_lib in _SUPPORTED_LIBRARIES.items():
# check if filename matches a supported prefix
prefix = self._check_prefix(filename, candidate_lib["filename_prefixes"])
# filename does not match any of the prefixes of the candidate
# library. move to next library.
if prefix is None:
continue
# filename matches a prefix. Create and store the library
# controller.
user_api = candidate_lib["user_api"]
internal_api = candidate_lib["internal_api"]
lib_controller_class = globals()[controller_class]
lib_controller = lib_controller_class(
filepath=filepath,
prefix=prefix,
user_api=user_api,
internal_api=internal_api,
)
self.lib_controllers.append(lib_controller)
def _check_prefix(self, library_basename, filename_prefixes):
"""Return the prefix library_basename starts with
Return None if none matches.
"""
for prefix in filename_prefixes:
if library_basename.startswith(prefix):
return prefix
return None
def _warn_if_incompatible_openmp(self):
"""Raise a warning if llvm-OpenMP and intel-OpenMP are both loaded"""
if sys.platform != "linux":
# Only issue the warning on Linux
return
prefixes = [lib_controller.prefix for lib_controller in self.lib_controllers]
msg = textwrap.dedent(
"""
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md
"""
)
if "libomp" in prefixes and "libiomp" in prefixes:
warnings.warn(msg, RuntimeWarning)
@classmethod
def _get_libc(cls):
"""Load the lib-C for unix systems."""
libc = cls._system_libraries.get("libc")
if libc is None:
libc_name = find_library("c")
if libc_name is None: # pragma: no cover
return None
libc = ctypes.CDLL(libc_name, mode=_RTLD_NOLOAD)
cls._system_libraries["libc"] = libc
return libc
@classmethod
def _get_windll(cls, dll_name):
"""Load a windows DLL"""
dll = cls._system_libraries.get(dll_name)
if dll is None:
dll = ctypes.WinDLL(f"{dll_name}.dll")
cls._system_libraries[dll_name] = dll
return dll
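# Introspection sketch (illustrative): each entry of `info()` is a dict with
# at least the keys listed in the LibController docstring below.
#
#     from threadpoolctl import ThreadpoolController
#
#     for lib_info in ThreadpoolController().info():
#         print(lib_info["internal_api"], lib_info["num_threads"])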
@_format_docstring(
USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
INTERNAL_APIS=", ".join('"{}"'.format(api) for api in _ALL_INTERNAL_APIS),
)
class LibController(ABC):
"""Abstract base class for the individual library controllers
A library controller is represented by the following information:
- "user_api" : user API. Possible values are {USER_APIS}.
- "internal_api" : internal API. Possible values are {INTERNAL_APIS}.
- "prefix" : prefix of the shared library's name.
- "filepath" : path to the loaded library.
- "version" : version of the library (if available).
- "num_threads" : the current thread limit.
In addition, each library controller may contain internal_api specific
entries.
"""
def __init__(self, *, filepath=None, prefix=None, user_api=None, internal_api=None):
self.user_api = user_api
self.internal_api = internal_api
self.prefix = prefix
self.filepath = filepath
self._dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
self.version = self.get_version()
def info(self):
"""Return relevant info wrapped in a dict"""
all_attrs = dict(vars(self), **{"num_threads": self.num_threads})
return {k: v for k, v in all_attrs.items() if not k.startswith("_")}
@property
def num_threads(self):
return self.get_num_threads()
@abstractmethod
def get_num_threads(self):
"""Return the maximum number of threads available to use"""
pass # pragma: no cover
@abstractmethod
def set_num_threads(self, num_threads):
"""Set the maximum number of threads to use"""
pass # pragma: no cover
@abstractmethod
def get_version(self):
"""Return the version of the shared library"""
pass # pragma: no cover
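# Minimal sketch of a concrete controller (hypothetical "FooBLAS" library,
# for illustration only; the real controllers below resolve their symbols
# through `self._dynlib` in the same way):
#
#     class FooBLASController(LibController):
#         def get_num_threads(self):
#             return getattr(self._dynlib, "fooblas_get_num_threads", lambda: None)()
#         def set_num_threads(self, num_threads):
#             getattr(self._dynlib, "fooblas_set_num_threads", lambda n: None)(num_threads)
#         def get_version(self):
#             return None  # hypothetical FooBLAS exposes no version symbol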
class OpenBLASController(LibController):
"""Controller class for OpenBLAS"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.threading_layer = self._get_threading_layer()
self.architecture = self._get_architecture()
def get_num_threads(self):
get_func = getattr(
self._dynlib,
"openblas_get_num_threads",
# Symbols differ when built for 64bit integers in Fortran
getattr(self._dynlib, "openblas_get_num_threads64_", lambda: None),
)
return get_func()
def set_num_threads(self, num_threads):
set_func = getattr(
self._dynlib,
"openblas_set_num_threads",
# Symbols differ when built for 64bit integers in Fortran
getattr(
self._dynlib, "openblas_set_num_threads64_", lambda num_threads: None
),
)
return set_func(num_threads)
def get_version(self):
# None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
# did not expose its version before that.
get_config = getattr(
self._dynlib,
"openblas_get_config",
getattr(self._dynlib, "openblas_get_config64_", None),
)
if get_config is None:
return None
get_config.restype = ctypes.c_char_p
config = get_config().split()
if config[0] == b"OpenBLAS":
return config[1].decode("utf-8")
return None
def _get_threading_layer(self):
"""Return the threading layer of OpenBLAS"""
openblas_get_parallel = getattr(
self._dynlib,
"openblas_get_parallel",
getattr(self._dynlib, "openblas_get_parallel64_", None),
)
if openblas_get_parallel is None:
return "unknown"
threading_layer = openblas_get_parallel()
if threading_layer == 2:
return "openmp"
elif threading_layer == 1:
return "pthreads"
return "disabled"
def _get_architecture(self):
"""Return the architecture detected by OpenBLAS"""
get_corename = getattr(
self._dynlib,
"openblas_get_corename",
getattr(self._dynlib, "openblas_get_corename64_", None),
)
if get_corename is None:
return None
get_corename.restype = ctypes.c_char_p
return get_corename().decode("utf-8")
class BLISController(LibController):
"""Controller class for BLIS"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.threading_layer = self._get_threading_layer()
self.architecture = self._get_architecture()
def get_num_threads(self):
get_func = getattr(self._dynlib, "bli_thread_get_num_threads", lambda: None)
num_threads = get_func()
# by default BLIS is single-threaded and get_num_threads
# returns -1. We map it to 1 for consistency with other libraries.
return 1 if num_threads == -1 else num_threads
def set_num_threads(self, num_threads):
set_func = getattr(
self._dynlib, "bli_thread_set_num_threads", lambda num_threads: None
)
return set_func(num_threads)
def get_version(self):
get_version_ = getattr(self._dynlib, "bli_info_get_version_str", None)
if get_version_ is None:
return None
get_version_.restype = ctypes.c_char_p
return get_version_().decode("utf-8")
def _get_threading_layer(self):
"""Return the threading layer of BLIS"""
if self._dynlib.bli_info_get_enable_openmp():
return "openmp"
elif self._dynlib.bli_info_get_enable_pthreads():
return "pthreads"
return "disabled"
def _get_architecture(self):
"""Return the architecture detected by BLIS"""
bli_arch_query_id = getattr(self._dynlib, "bli_arch_query_id", None)
bli_arch_string = getattr(self._dynlib, "bli_arch_string", None)
if bli_arch_query_id is None or bli_arch_string is None:
return None
# the true restype should be BLIS' arch_t (enum) but int should work
# for us:
bli_arch_query_id.restype = ctypes.c_int
bli_arch_string.restype = ctypes.c_char_p
return bli_arch_string(bli_arch_query_id()).decode("utf-8")
class MKLController(LibController):
"""Controller class for MKL"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.threading_layer = self._get_threading_layer()
def get_num_threads(self):
get_func = getattr(self._dynlib, "MKL_Get_Max_Threads", lambda: None)
return get_func()
def set_num_threads(self, num_threads):
set_func = getattr(
self._dynlib, "MKL_Set_Num_Threads", lambda num_threads: None
)
return set_func(num_threads)
def get_version(self):
if not hasattr(self._dynlib, "MKL_Get_Version_String"):
return None
res = ctypes.create_string_buffer(200)
self._dynlib.MKL_Get_Version_String(res, 200)
version = res.value.decode("utf-8")
group = re.search(r"Version ([^ ]+) ", version)
if group is not None:
version = group.groups()[0]
return version.strip()
def _get_threading_layer(self):
"""Return the threading layer of MKL"""
# The function MKL_Set_Threading_Layer returns the current threading
# layer. Calling it with an invalid value (-1) lets us safely read the
# current layer without changing it.
set_threading_layer = getattr(
self._dynlib, "MKL_Set_Threading_Layer", lambda layer: -1
)
layer_map = {
0: "intel",
1: "sequential",
2: "pgi",
3: "gnu",
4: "tbb",
-1: "not specified",
}
return layer_map[set_threading_layer(-1)]
class OpenMPController(LibController):
"""Controller class for OpenMP"""
def get_num_threads(self):
get_func = getattr(self._dynlib, "omp_get_max_threads", lambda: None)
return get_func()
def set_num_threads(self, num_threads):
set_func = getattr(
self._dynlib, "omp_set_num_threads", lambda num_threads: None
)
return set_func(num_threads)
def get_version(self):
# There is no way to get the version number programmatically in OpenMP.
return None
def _main():
"""Commandline interface to display thread-pool information and exit."""
import argparse
import importlib
import json
import sys
parser = argparse.ArgumentParser(
usage="python -m threadpoolctl -i numpy scipy.linalg xgboost",
description="Display thread-pool information and exit.",
)
parser.add_argument(
"-i",
"--import",
dest="modules",
nargs="*",
default=(),
help="Python modules to import before introspecting thread-pools.",
)
parser.add_argument(
"-c",
"--command",
help="a Python statement to execute before introspecting thread-pools.",
)
options = parser.parse_args(sys.argv[1:])
for module in options.modules:
try:
importlib.import_module(module, package=None)
except ImportError:
print("WARNING: could not import", module, file=sys.stderr)
if options.command:
exec(options.command)
print(json.dumps(threadpool_info(), indent=2))
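# Example invocation (matching the usage string above; the output is the JSON
# list of per-library info dicts returned by threadpool_info()):
#
#     python -m threadpoolctl -i numpy scipy.linalg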
if __name__ == "__main__":
_main()
import abc
import collections
import collections.abc
import operator
import sys
import typing
# After PEP 560, the internal typing API was substantially reworked.
# This is especially important for the Protocol class, which uses internal
# APIs quite extensively.
PEP_560 = sys.version_info[:3] >= (3, 7, 0)
if PEP_560:
GenericMeta = type
else:
# 3.6
from typing import GenericMeta, _type_vars # noqa
# The two functions below are copies of typing internal helpers.
# They are needed by _ProtocolMeta
def _no_slots_copy(dct):
dict_copy = dict(dct)
if '__slots__' in dict_copy:
for slot in dict_copy['__slots__']:
dict_copy.pop(slot, None)
return dict_copy
def _check_generic(cls, parameters):
if not cls.__parameters__:
raise TypeError(f"{cls} is not a generic class")
alen = len(parameters)
elen = len(cls.__parameters__)
if alen != elen:
raise TypeError(f"Too {'many' if alen > elen else 'few'} arguments for {cls};"
f" actual {alen}, expected {elen}")
# Please keep __all__ alphabetized within each category.
__all__ = [
# Super-special typing primitives.
'ClassVar',
'Concatenate',
'Final',
'ParamSpec',
'Self',
'Type',
# ABCs (from collections.abc).
'Awaitable',
'AsyncIterator',
'AsyncIterable',
'Coroutine',
'AsyncGenerator',
'AsyncContextManager',
'ChainMap',
# Concrete collection types.
'ContextManager',
'Counter',
'Deque',
'DefaultDict',
'OrderedDict',
'TypedDict',
# Structural checks, a.k.a. protocols.
'SupportsIndex',
# One-off things.
'Annotated',
'final',
'IntVar',
'Literal',
'NewType',
'overload',
'Protocol',
'runtime',
'runtime_checkable',
'Text',
'TypeAlias',
'TypeGuard',
'TYPE_CHECKING',
]
if PEP_560:
__all__.extend(["get_args", "get_origin", "get_type_hints"])
# 3.6.2+
if hasattr(typing, 'NoReturn'):
NoReturn = typing.NoReturn
# 3.6.0-3.6.1
else:
class _NoReturn(typing._FinalTypingBase, _root=True):
"""Special type indicating functions that never return.
Example::
from typing import NoReturn
def stop() -> NoReturn:
raise Exception('no way')
This type is invalid in other positions, e.g., ``List[NoReturn]``
will fail in static type checkers.
"""
__slots__ = ()
def __instancecheck__(self, obj):
raise TypeError("NoReturn cannot be used with isinstance().")
def __subclasscheck__(self, cls):
raise TypeError("NoReturn cannot be used with issubclass().")
NoReturn = _NoReturn(_root=True)
# Some unconstrained type variables. These are used by the container types.
# (These are not for export.)
T = typing.TypeVar('T') # Any type.
KT = typing.TypeVar('KT') # Key type.
VT = typing.TypeVar('VT') # Value type.
T_co = typing.TypeVar('T_co', covariant=True) # Any type covariant containers.
T_contra = typing.TypeVar('T_contra', contravariant=True) # Ditto contravariant.
ClassVar = typing.ClassVar
# On older versions of typing there is an internal class named "Final".
# 3.8+
if hasattr(typing, 'Final') and sys.version_info[:2] >= (3, 7):
Final = typing.Final
# 3.7
elif sys.version_info[:2] >= (3, 7):
class _FinalForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
def __getitem__(self, parameters):
item = typing._type_check(parameters,
f'{self._name} accepts only single type')
return typing._GenericAlias(self, (item,))
Final = _FinalForm('Final',
doc="""A special typing construct to indicate that a name
cannot be re-assigned or overridden in a subclass.
For example:
MAX_SIZE: Final = 9000
MAX_SIZE += 1 # Error reported by type checker
class Connection:
TIMEOUT: Final[int] = 10
class FastConnector(Connection):
TIMEOUT = 1 # Error reported by type checker
There is no runtime checking of these properties.""")
# 3.6
else:
class _Final(typing._FinalTypingBase, _root=True):
"""A special typing construct to indicate that a name
cannot be re-assigned or overridden in a subclass.
For example:
MAX_SIZE: Final = 9000
MAX_SIZE += 1 # Error reported by type checker
class Connection:
TIMEOUT: Final[int] = 10
class FastConnector(Connection):
TIMEOUT = 1 # Error reported by type checker
There is no runtime checking of these properties.
"""
__slots__ = ('__type__',)
def __init__(self, tp=None, **kwds):
self.__type__ = tp
def __getitem__(self, item):
cls = type(self)
if self.__type__ is None:
return cls(typing._type_check(item,
f'{cls.__name__[1:]} accepts only single type.'),
_root=True)
raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted')
def _eval_type(self, globalns, localns):
new_tp = typing._eval_type(self.__type__, globalns, localns)
if new_tp == self.__type__:
return self
return type(self)(new_tp, _root=True)
def __repr__(self):
r = super().__repr__()
if self.__type__ is not None:
r += f'[{typing._type_repr(self.__type__)}]'
return r
def __hash__(self):
return hash((type(self).__name__, self.__type__))
def __eq__(self, other):
if not isinstance(other, _Final):
return NotImplemented
if self.__type__ is not None:
return self.__type__ == other.__type__
return self is other
Final = _Final(_root=True)
# 3.8+
if hasattr(typing, 'final'):
final = typing.final
# 3.6-3.7
else:
def final(f):
"""This decorator can be used to indicate to type checkers that
the decorated method cannot be overridden, and decorated class
cannot be subclassed. For example:
class Base:
@final
def done(self) -> None:
...
class Sub(Base):
def done(self) -> None: # Error reported by type checker
...
@final
class Leaf:
...
class Other(Leaf): # Error reported by type checker
...
There is no runtime checking of these properties.
"""
return f
def IntVar(name):
return typing.TypeVar(name)
# 3.8+:
if hasattr(typing, 'Literal'):
Literal = typing.Literal
# 3.7:
elif sys.version_info[:2] >= (3, 7):
class _LiteralForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
def __getitem__(self, parameters):
return typing._GenericAlias(self, parameters)
Literal = _LiteralForm('Literal',
doc="""A type that can be used to indicate to type checkers
that the corresponding value has a value literally equivalent
to the provided parameter. For example:
var: Literal[4] = 4
The type checker understands that 'var' is literally equal to
the value 4 and no other value.
Literal[...] cannot be subclassed. There is no runtime
checking verifying that the parameter is actually a value
instead of a type.""")
# 3.6:
else:
class _Literal(typing._FinalTypingBase, _root=True):
"""A type that can be used to indicate to type checkers that the
corresponding value has a value literally equivalent to the
provided parameter. For example:
var: Literal[4] = 4
The type checker understands that 'var' is literally equal to the
value 4 and no other value.
Literal[...] cannot be subclassed. There is no runtime checking
verifying that the parameter is actually a value instead of a type.
"""
__slots__ = ('__values__',)
def __init__(self, values=None, **kwds):
self.__values__ = values
def __getitem__(self, values):
cls = type(self)
if self.__values__ is None:
if not isinstance(values, tuple):
values = (values,)
return cls(values, _root=True)
raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted')
def _eval_type(self, globalns, localns):
return self
def __repr__(self):
r = super().__repr__()
if self.__values__ is not None:
r += f'[{", ".join(map(typing._type_repr, self.__values__))}]'
return r
def __hash__(self):
return hash((type(self).__name__, self.__values__))
def __eq__(self, other):
if not isinstance(other, _Literal):
return NotImplemented
if self.__values__ is not None:
return self.__values__ == other.__values__
return self is other
Literal = _Literal(_root=True)
_overload_dummy = typing._overload_dummy # noqa
overload = typing.overload
# This is not a real generic class. Don't use outside annotations.
Type = typing.Type
# Various ABCs mimicking those in collections.abc.
# A few are simply re-exported for completeness.
class _ExtensionsGenericMeta(GenericMeta):
def __subclasscheck__(self, subclass):
"""This mimics a more modern GenericMeta.__subclasscheck__() logic
(that does not have problems with recursion) to work around interactions
between collections, typing, and typing_extensions on older
versions of Python, see https://github.com/python/typing/issues/501.
"""
if self.__origin__ is not None:
if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']:
raise TypeError("Parameterized generics cannot be used with class "
"or instance checks")
return False
if not self.__extra__:
return super().__subclasscheck__(subclass)
res = self.__extra__.__subclasshook__(subclass)
if res is not NotImplemented:
return res
if self.__extra__ in subclass.__mro__:
return True
for scls in self.__extra__.__subclasses__():
if isinstance(scls, GenericMeta):
continue
if issubclass(subclass, scls):
return True
return False
Awaitable = typing.Awaitable
Coroutine = typing.Coroutine
AsyncIterable = typing.AsyncIterable
AsyncIterator = typing.AsyncIterator
# 3.6.1+
if hasattr(typing, 'Deque'):
Deque = typing.Deque
# 3.6.0
else:
class Deque(collections.deque, typing.MutableSequence[T],
metaclass=_ExtensionsGenericMeta,
extra=collections.deque):
__slots__ = ()
def __new__(cls, *args, **kwds):
if cls._gorg is Deque:
return collections.deque(*args, **kwds)
return typing._generic_new(collections.deque, cls, *args, **kwds)
ContextManager = typing.ContextManager
# 3.6.2+
if hasattr(typing, 'AsyncContextManager'):
AsyncContextManager = typing.AsyncContextManager
# 3.6.0-3.6.1
else:
from _collections_abc import _check_methods as _check_methods_in_mro # noqa
class AsyncContextManager(typing.Generic[T_co]):
__slots__ = ()
async def __aenter__(self):
return self
@abc.abstractmethod
async def __aexit__(self, exc_type, exc_value, traceback):
return None
@classmethod
def __subclasshook__(cls, C):
if cls is AsyncContextManager:
return _check_methods_in_mro(C, "__aenter__", "__aexit__")
return NotImplemented
DefaultDict = typing.DefaultDict
# 3.7.2+
if hasattr(typing, 'OrderedDict'):
OrderedDict = typing.OrderedDict
# 3.7.0-3.7.2
elif (3, 7, 0) <= sys.version_info[:3] < (3, 7, 2):
OrderedDict = typing._alias(collections.OrderedDict, (KT, VT))
# 3.6
else:
class OrderedDict(collections.OrderedDict, typing.MutableMapping[KT, VT],
metaclass=_ExtensionsGenericMeta,
extra=collections.OrderedDict):
__slots__ = ()
def __new__(cls, *args, **kwds):
if cls._gorg is OrderedDict:
return collections.OrderedDict(*args, **kwds)
return typing._generic_new(collections.OrderedDict, cls, *args, **kwds)
# 3.6.2+
if hasattr(typing, 'Counter'):
Counter = typing.Counter
# 3.6.0-3.6.1
else:
class Counter(collections.Counter,
typing.Dict[T, int],
metaclass=_ExtensionsGenericMeta, extra=collections.Counter):
__slots__ = ()
def __new__(cls, *args, **kwds):
if cls._gorg is Counter:
return collections.Counter(*args, **kwds)
return typing._generic_new(collections.Counter, cls, *args, **kwds)
# 3.6.1+
if hasattr(typing, 'ChainMap'):
ChainMap = typing.ChainMap
elif hasattr(collections, 'ChainMap'):
class ChainMap(collections.ChainMap, typing.MutableMapping[KT, VT],
metaclass=_ExtensionsGenericMeta,
extra=collections.ChainMap):
__slots__ = ()
def __new__(cls, *args, **kwds):
if cls._gorg is ChainMap:
return collections.ChainMap(*args, **kwds)
return typing._generic_new(collections.ChainMap, cls, *args, **kwds)
# 3.6.1+
if hasattr(typing, 'AsyncGenerator'):
AsyncGenerator = typing.AsyncGenerator
# 3.6.0
else:
class AsyncGenerator(AsyncIterator[T_co], typing.Generic[T_co, T_contra],
metaclass=_ExtensionsGenericMeta,
extra=collections.abc.AsyncGenerator):
__slots__ = ()
NewType = typing.NewType
Text = typing.Text
TYPE_CHECKING = typing.TYPE_CHECKING
def _gorg(cls):
"""This function exists for compatibility with old typing versions."""
assert isinstance(cls, GenericMeta)
if hasattr(cls, '_gorg'):
return cls._gorg
while cls.__origin__ is not None:
cls = cls.__origin__
return cls
_PROTO_WHITELIST = ['Callable', 'Awaitable',
'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator',
'Hashable', 'Sized', 'Container', 'Collection', 'Reversible',
'ContextManager', 'AsyncContextManager']
def _get_protocol_attrs(cls):
attrs = set()
for base in cls.__mro__[:-1]: # without object
if base.__name__ in ('Protocol', 'Generic'):
continue
annotations = getattr(base, '__annotations__', {})
for attr in list(base.__dict__.keys()) + list(annotations.keys()):
if (not attr.startswith('_abc_') and attr not in (
'__abstractmethods__', '__annotations__', '__weakref__',
'_is_protocol', '_is_runtime_protocol', '__dict__',
'__args__', '__slots__',
'__next_in_mro__', '__parameters__', '__origin__',
'__orig_bases__', '__extra__', '__tree_hash__',
'__doc__', '__subclasshook__', '__init__', '__new__',
'__module__', '_MutableMapping__marker', '_gorg')):
attrs.add(attr)
return attrs
def _is_callable_members_only(cls):
return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))
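# Sketch (illustrative): for a protocol with a data member and a method,
# _get_protocol_attrs collects both names, and _is_callable_members_only is
# False because 'x' is not callable, so issubclass() is rejected for it:
#
#     class P(Protocol):
#         x: int
#         def f(self) -> None: ...
#
#     # _get_protocol_attrs(P) == {'x', 'f'}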
# 3.8+
if hasattr(typing, 'Protocol'):
Protocol = typing.Protocol
# 3.7
elif PEP_560:
from typing import _collect_type_vars # noqa
def _no_init(self, *args, **kwargs):
if type(self)._is_protocol:
raise TypeError('Protocols cannot be instantiated')
class _ProtocolMeta(abc.ABCMeta):
# This metaclass is a bit unfortunate and exists only because of the lack
# of __instancehook__.
def __instancecheck__(cls, instance):
# We need this method for situations where attributes are
# assigned in __init__.
if ((not getattr(cls, '_is_protocol', False) or
_is_callable_members_only(cls)) and
issubclass(instance.__class__, cls)):
return True
if cls._is_protocol:
if all(hasattr(instance, attr) and
(not callable(getattr(cls, attr, None)) or
getattr(instance, attr) is not None)
for attr in _get_protocol_attrs(cls)):
return True
return super().__instancecheck__(instance)
class Protocol(metaclass=_ProtocolMeta):
# There is quite a lot of overlapping code with typing.Generic.
# Unfortunately it is hard to avoid this while these live in two different
# modules. The duplicated code will be removed when Protocol is moved to typing.
"""Base class for protocol classes. Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize
structural subtyping (static duck-typing), for example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with
@typing_extensions.runtime act as simple-minded runtime protocols that check
only the presence of given attributes, ignoring their type signatures.
Protocol classes can be generic; they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
"""
__slots__ = ()
_is_protocol = True
def __new__(cls, *args, **kwds):
if cls is Protocol:
raise TypeError("Type Protocol cannot be instantiated; "
"it can only be used as a base class")
return super().__new__(cls)
@typing._tp_cache
def __class_getitem__(cls, params):
if not isinstance(params, tuple):
params = (params,)
if not params and cls is not typing.Tuple:
raise TypeError(
f"Parameter list to {cls.__qualname__}[...] cannot be empty")
msg = "Parameters to generic types must be types."
params = tuple(typing._type_check(p, msg) for p in params) # noqa
if cls is Protocol:
# Generic can only be subscripted with unique type variables.
if not all(isinstance(p, typing.TypeVar) for p in params):
i = 0
while isinstance(params[i], typing.TypeVar):
i += 1
raise TypeError(
"Parameters to Protocol[...] must all be type variables."
f" Parameter {i + 1} is {params[i]}")
if len(set(params)) != len(params):
raise TypeError(
"Parameters to Protocol[...] must all be unique")
else:
# Subscripting a regular Generic subclass.
_check_generic(cls, params)
return typing._GenericAlias(cls, params)
def __init_subclass__(cls, *args, **kwargs):
tvars = []
if '__orig_bases__' in cls.__dict__:
error = typing.Generic in cls.__orig_bases__
else:
error = typing.Generic in cls.__bases__
if error:
raise TypeError("Cannot inherit from plain Generic")
if '__orig_bases__' in cls.__dict__:
tvars = _collect_type_vars(cls.__orig_bases__)
# Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
# Also check for and reject plain Generic,
# and reject multiple Generic[...] and/or Protocol[...].
gvars = None
for base in cls.__orig_bases__:
if (isinstance(base, typing._GenericAlias) and
base.__origin__ in (typing.Generic, Protocol)):
# for error messages
the_base = base.__origin__.__name__
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...]"
" and/or Protocol[...] multiple types.")
gvars = base.__parameters__
if gvars is None:
gvars = tvars
else:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
s_args = ', '.join(str(g) for g in gvars)
raise TypeError(f"Some type variables ({s_vars}) are"
f" not listed in {the_base}[{s_args}]")
tvars = gvars
cls.__parameters__ = tuple(tvars)
# Determine if this is a protocol or a concrete subclass.
if not cls.__dict__.get('_is_protocol', None):
cls._is_protocol = any(b is Protocol for b in cls.__bases__)
# Set (or override) the protocol subclass hook.
def _proto_hook(other):
if not cls.__dict__.get('_is_protocol', None):
return NotImplemented
if not getattr(cls, '_is_runtime_protocol', False):
if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']:
return NotImplemented
raise TypeError("Instance and class checks can only be used with"
" @runtime protocols")
if not _is_callable_members_only(cls):
if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']:
return NotImplemented
raise TypeError("Protocols with non-method members"
" don't support issubclass()")
if not isinstance(other, type):
# Same error as for issubclass(1, int)
raise TypeError('issubclass() arg 1 must be a class')
for attr in _get_protocol_attrs(cls):
for base in other.__mro__:
if attr in base.__dict__:
if base.__dict__[attr] is None:
return NotImplemented
break
annotations = getattr(base, '__annotations__', {})
if (isinstance(annotations, typing.Mapping) and
attr in annotations and
isinstance(other, _ProtocolMeta) and
other._is_protocol):
break
else:
return NotImplemented
return True
if '__subclasshook__' not in cls.__dict__:
cls.__subclasshook__ = _proto_hook
# We have nothing more to do for non-protocols.
if not cls._is_protocol:
return
# Check consistency of bases.
for base in cls.__bases__:
if not (base in (object, typing.Generic) or
base.__module__ == 'collections.abc' and
base.__name__ in _PROTO_WHITELIST or
isinstance(base, _ProtocolMeta) and base._is_protocol):
raise TypeError('Protocols can only inherit from other'
f' protocols, got {repr(base)}')
cls.__init__ = _no_init
# 3.6
else:
from typing import _next_in_mro, _type_check # noqa
def _no_init(self, *args, **kwargs):
if type(self)._is_protocol:
raise TypeError('Protocols cannot be instantiated')
class _ProtocolMeta(GenericMeta):
"""Internal metaclass for Protocol.
This exists so Protocol classes can be generic without deriving
from Generic.
"""
def __new__(cls, name, bases, namespace,
tvars=None, args=None, origin=None, extra=None, orig_bases=None):
# This is just a version copied from GenericMeta.__new__ that
# includes "Protocol" special treatment. (Comments removed for brevity.)
assert extra is None # Protocols should not have extra
if tvars is not None:
assert origin is not None
assert all(isinstance(t, typing.TypeVar) for t in tvars), tvars
else:
tvars = _type_vars(bases)
gvars = None
for base in bases:
if base is typing.Generic:
raise TypeError("Cannot inherit from plain Generic")
if (isinstance(base, GenericMeta) and
base.__origin__ in (typing.Generic, Protocol)):
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...] or"
" Protocol[...] multiple times.")
gvars = base.__parameters__
if gvars is None:
gvars = tvars
else:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
s_vars = ", ".join(str(t) for t in tvars if t not in gvarset)
s_args = ", ".join(str(g) for g in gvars)
cls_name = "Generic" if any(b.__origin__ is typing.Generic
for b in bases) else "Protocol"
raise TypeError(f"Some type variables ({s_vars}) are"
f" not listed in {cls_name}[{s_args}]")
tvars = gvars
initial_bases = bases
if (extra is not None and type(extra) is abc.ABCMeta and
extra not in bases):
bases = (extra,) + bases
bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b
for b in bases)
if any(isinstance(b, GenericMeta) and b is not typing.Generic for b in bases):
bases = tuple(b for b in bases if b is not typing.Generic)
namespace.update({'__origin__': origin, '__extra__': extra})
self = super(GenericMeta, cls).__new__(cls, name, bases, namespace,
_root=True)
super(GenericMeta, self).__setattr__('_gorg',
self if not origin else
_gorg(origin))
self.__parameters__ = tvars
self.__args__ = tuple(... if a is typing._TypingEllipsis else
() if a is typing._TypingEmpty else
a for a in args) if args else None
self.__next_in_mro__ = _next_in_mro(self)
if orig_bases is None:
self.__orig_bases__ = initial_bases
elif origin is not None:
self._abc_registry = origin._abc_registry
self._abc_cache = origin._abc_cache
if hasattr(self, '_subs_tree'):
self.__tree_hash__ = (hash(self._subs_tree()) if origin else
super(GenericMeta, self).__hash__())
return self
def __init__(cls, *args, **kwargs):
super().__init__(*args, **kwargs)
if not cls.__dict__.get('_is_protocol', None):
cls._is_protocol = any(b is Protocol or
isinstance(b, _ProtocolMeta) and
b.__origin__ is Protocol
for b in cls.__bases__)
if cls._is_protocol:
for base in cls.__mro__[1:]:
if not (base in (object, typing.Generic) or
base.__module__ == 'collections.abc' and
base.__name__ in _PROTO_WHITELIST or
isinstance(base, typing.TypingMeta) and base._is_protocol or
isinstance(base, GenericMeta) and
base.__origin__ is typing.Generic):
raise TypeError(f'Protocols can only inherit from other'
f' protocols, got {repr(base)}')
cls.__init__ = _no_init
def _proto_hook(other):
if not cls.__dict__.get('_is_protocol', None):
return NotImplemented
if not isinstance(other, type):
# Same error as for issubclass(1, int)
raise TypeError('issubclass() arg 1 must be a class')
for attr in _get_protocol_attrs(cls):
for base in other.__mro__:
if attr in base.__dict__:
if base.__dict__[attr] is None:
return NotImplemented
break
annotations = getattr(base, '__annotations__', {})
if (isinstance(annotations, typing.Mapping) and
attr in annotations and
isinstance(other, _ProtocolMeta) and
other._is_protocol):
break
else:
return NotImplemented
return True
if '__subclasshook__' not in cls.__dict__:
cls.__subclasshook__ = _proto_hook
def __instancecheck__(self, instance):
# We need this method for situations where attributes are
# assigned in __init__.
if ((not getattr(self, '_is_protocol', False) or
_is_callable_members_only(self)) and
issubclass(instance.__class__, self)):
return True
if self._is_protocol:
if all(hasattr(instance, attr) and
(not callable(getattr(self, attr, None)) or
getattr(instance, attr) is not None)
for attr in _get_protocol_attrs(self)):
return True
return super(GenericMeta, self).__instancecheck__(instance)
def __subclasscheck__(self, cls):
if self.__origin__ is not None:
if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']:
raise TypeError("Parameterized generics cannot be used with class "
"or instance checks")
return False
if (self.__dict__.get('_is_protocol', None) and
not self.__dict__.get('_is_runtime_protocol', None)):
if sys._getframe(1).f_globals['__name__'] in ['abc',
'functools',
'typing']:
return False
raise TypeError("Instance and class checks can only be used with"
" @runtime protocols")
if (self.__dict__.get('_is_runtime_protocol', None) and
not _is_callable_members_only(self)):
if sys._getframe(1).f_globals['__name__'] in ['abc',
'functools',
'typing']:
return super(GenericMeta, self).__subclasscheck__(cls)
raise TypeError("Protocols with non-method members"
" don't support issubclass()")
return super(GenericMeta, self).__subclasscheck__(cls)
@typing._tp_cache
def __getitem__(self, params):
# We also need to copy this from GenericMeta.__getitem__ to get
# special treatment of "Protocol". (Comments removed for brevity.)
if not isinstance(params, tuple):
params = (params,)
if not params and _gorg(self) is not typing.Tuple:
raise TypeError(
f"Parameter list to {self.__qualname__}[...] cannot be empty")
msg = "Parameters to generic types must be types."
params = tuple(_type_check(p, msg) for p in params)
if self in (typing.Generic, Protocol):
if not all(isinstance(p, typing.TypeVar) for p in params):
raise TypeError(
f"Parameters to {repr(self)}[...] must all be type variables")
if len(set(params)) != len(params):
raise TypeError(
f"Parameters to {repr(self)}[...] must all be unique")
tvars = params
args = params
elif self in (typing.Tuple, typing.Callable):
tvars = _type_vars(params)
args = params
elif self.__origin__ in (typing.Generic, Protocol):
raise TypeError(f"Cannot subscript already-subscripted {repr(self)}")
else:
_check_generic(self, params)
tvars = _type_vars(params)
args = params
prepend = (self,) if self.__origin__ is None else ()
return self.__class__(self.__name__,
prepend + self.__bases__,
_no_slots_copy(self.__dict__),
tvars=tvars,
args=args,
origin=self,
extra=self.__extra__,
orig_bases=self.__orig_bases__)
class Protocol(metaclass=_ProtocolMeta):
"""Base class for protocol classes. Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize
structural subtyping (static duck-typing), for example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with
@typing_extensions.runtime act as simple-minded runtime protocols that check
only the presence of given attributes, ignoring their type signatures.
Protocol classes can be generic; they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
"""
__slots__ = ()
_is_protocol = True
def __new__(cls, *args, **kwds):
if _gorg(cls) is Protocol:
raise TypeError("Type Protocol cannot be instantiated; "
"it can be used only as a base class")
return typing._generic_new(cls.__next_in_mro__, cls, *args, **kwds)
# 3.8+
if hasattr(typing, 'runtime_checkable'):
runtime_checkable = typing.runtime_checkable
# 3.6-3.7
else:
def runtime_checkable(cls):
"""Mark a protocol class as a runtime protocol, so that it
can be used with isinstance() and issubclass(). Raise TypeError
if applied to a non-protocol class.
This allows a simple-minded structural check very similar to the
one-offs in collections.abc such as Hashable.
"""
if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol:
raise TypeError('@runtime_checkable can be only applied to protocol classes,'
f' got {cls!r}')
cls._is_runtime_protocol = True
return cls
# Exists for backwards compatibility.
runtime = runtime_checkable
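# Usage sketch (illustrative):
#
#     @runtime_checkable
#     class Closable(Protocol):
#         def close(self) -> None: ...
#
#     class File:
#         def close(self) -> None: ...
#
#     isinstance(File(), Closable)  # True: structural check on 'close' only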
# 3.8+
if hasattr(typing, 'SupportsIndex'):
SupportsIndex = typing.SupportsIndex
# 3.6-3.7
else:
@runtime_checkable
class SupportsIndex(Protocol):
__slots__ = ()
@abc.abstractmethod
def __index__(self) -> int:
pass
if sys.version_info >= (3, 9, 2):
# The standard library TypedDict in Python 3.8 does not store runtime information
# about which (if any) keys are optional. See https://bugs.python.org/issue38834
# The standard library TypedDict in Python 3.9.0/1 does not honour the "total"
# keyword with old-style TypedDict(). See https://bugs.python.org/issue42059
TypedDict = typing.TypedDict
else:
def _check_fails(cls, other):
try:
if sys._getframe(1).f_globals['__name__'] not in ['abc',
'functools',
'typing']:
# Typed dicts are only for static structural subtyping.
raise TypeError('TypedDict does not support instance and class checks')
except (AttributeError, ValueError):
pass
return False
def _dict_new(*args, **kwargs):
if not args:
raise TypeError('TypedDict.__new__(): not enough arguments')
_, args = args[0], args[1:] # allow the "cls" keyword to be passed
return dict(*args, **kwargs)
_dict_new.__text_signature__ = '($cls, _typename, _fields=None, /, **kwargs)'
def _typeddict_new(*args, total=True, **kwargs):
if not args:
raise TypeError('TypedDict.__new__(): not enough arguments')
_, args = args[0], args[1:] # allow the "cls" keyword to be passed
if args:
typename, args = args[0], args[1:] # allow the "_typename" keyword to be passed
elif '_typename' in kwargs:
typename = kwargs.pop('_typename')
import warnings
warnings.warn("Passing '_typename' as keyword argument is deprecated",
DeprecationWarning, stacklevel=2)
else:
raise TypeError("TypedDict.__new__() missing 1 required positional "
"argument: '_typename'")
if args:
try:
fields, = args # allow the "_fields" keyword to be passed
except ValueError:
raise TypeError('TypedDict.__new__() takes from 2 to 3 '
f'positional arguments but {len(args) + 2} '
'were given')
elif '_fields' in kwargs and len(kwargs) == 1:
fields = kwargs.pop('_fields')
import warnings
warnings.warn("Passing '_fields' as keyword argument is deprecated",
DeprecationWarning, stacklevel=2)
else:
fields = None
if fields is None:
fields = kwargs
elif kwargs:
raise TypeError("TypedDict takes either a dict or keyword arguments,"
" but not both")
ns = {'__annotations__': dict(fields)}
try:
# Setting correct module is necessary to make typed dict classes pickleable.
ns['__module__'] = sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
pass
return _TypedDictMeta(typename, (), ns, total=total)
_typeddict_new.__text_signature__ = ('($cls, _typename, _fields=None,'
' /, *, total=True, **kwargs)')
class _TypedDictMeta(type):
def __init__(cls, name, bases, ns, total=True):
super().__init__(name, bases, ns)
def __new__(cls, name, bases, ns, total=True):
# Create new typed dict class object.
# This method is called directly when TypedDict is subclassed,
# or via _typeddict_new when TypedDict is instantiated. This way
# TypedDict supports all three syntaxes described in its docstring.
# Subclasses and instances of TypedDict return actual dictionaries
# via _dict_new.
ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new
tp_dict = super().__new__(cls, name, (dict,), ns)
annotations = {}
own_annotations = ns.get('__annotations__', {})
own_annotation_keys = set(own_annotations.keys())
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
own_annotations = {
n: typing._type_check(tp, msg) for n, tp in own_annotations.items()
}
required_keys = set()
optional_keys = set()
for base in bases:
annotations.update(base.__dict__.get('__annotations__', {}))
required_keys.update(base.__dict__.get('__required_keys__', ()))
optional_keys.update(base.__dict__.get('__optional_keys__', ()))
annotations.update(own_annotations)
if total:
required_keys.update(own_annotation_keys)
else:
optional_keys.update(own_annotation_keys)
tp_dict.__annotations__ = annotations
tp_dict.__required_keys__ = frozenset(required_keys)
tp_dict.__optional_keys__ = frozenset(optional_keys)
if not hasattr(tp_dict, '__total__'):
tp_dict.__total__ = total
return tp_dict
__instancecheck__ = __subclasscheck__ = _check_fails
TypedDict = _TypedDictMeta('TypedDict', (dict,), {})
TypedDict.__module__ = __name__
TypedDict.__doc__ = \
"""A simple typed name space. At runtime it is equivalent to a plain dict.
TypedDict creates a dictionary type that expects all of its
instances to have a certain set of keys, with each key
associated with a value of a consistent type. This expectation
is not checked at runtime but is only enforced by type checkers.
Usage::
class Point2D(TypedDict):
x: int
y: int
label: str
a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK
b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check
assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
The type info can be accessed via the Point2D.__annotations__ dict, and
the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets.
TypedDict supports two additional equivalent forms::
Point2D = TypedDict('Point2D', x=int, y=int, label=str)
Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str})
The class syntax is only supported in Python 3.6+, while the two other
syntax forms work for Python 2.7 and 3.2+.
"""
# Python 3.9+ has PEP 593 (Annotated and modified get_type_hints)
if hasattr(typing, 'Annotated'):
Annotated = typing.Annotated
get_type_hints = typing.get_type_hints
# Not exported and not a public API, but needed for get_origin() and get_args()
# to work.
_AnnotatedAlias = typing._AnnotatedAlias
# 3.7-3.8
elif PEP_560:
class _AnnotatedAlias(typing._GenericAlias, _root=True):
"""Runtime representation of an annotated type.
At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't'
with extra annotations. The alias behaves like a normal typing alias:
instantiating it is the same as instantiating the underlying type, and
binding it to types is also the same.
"""
def __init__(self, origin, metadata):
if isinstance(origin, _AnnotatedAlias):
metadata = origin.__metadata__ + metadata
origin = origin.__origin__
super().__init__(origin, origin)
self.__metadata__ = metadata
def copy_with(self, params):
assert len(params) == 1
new_type = params[0]
return _AnnotatedAlias(new_type, self.__metadata__)
def __repr__(self):
return (f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, "
f"{', '.join(repr(a) for a in self.__metadata__)}]")
def __reduce__(self):
return operator.getitem, (
Annotated, (self.__origin__,) + self.__metadata__
)
def __eq__(self, other):
if not isinstance(other, _AnnotatedAlias):
return NotImplemented
if self.__origin__ != other.__origin__:
return False
return self.__metadata__ == other.__metadata__
def __hash__(self):
return hash((self.__origin__, self.__metadata__))
class Annotated:
"""Add context specific metadata to a type.
Example: Annotated[int, runtime_check.Unsigned] indicates to the
hypothetical runtime_check module that this type is an unsigned int.
Every other consumer of this type can ignore this metadata and treat
this type as int.
The first argument to Annotated must be a valid type (and will be in
the __origin__ field); the remaining arguments are kept as a tuple in
the __metadata__ field.
Details:
- It's an error to call `Annotated` with less than two arguments.
- Nested Annotated are flattened::
Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3]
- Instantiating an annotated type is equivalent to instantiating the
underlying type::
Annotated[C, Ann1](5) == C(5)
- Annotated can be used as a generic type alias::
Optimized = Annotated[T, runtime.Optimize()]
Optimized[int] == Annotated[int, runtime.Optimize()]
OptimizedList = Annotated[List[T], runtime.Optimize()]
OptimizedList[int] == Annotated[List[int], runtime.Optimize()]
"""
__slots__ = ()
def __new__(cls, *args, **kwargs):
raise TypeError("Type Annotated cannot be instantiated.")
@typing._tp_cache
def __class_getitem__(cls, params):
if not isinstance(params, tuple) or len(params) < 2:
raise TypeError("Annotated[...] should be used "
"with at least two arguments (a type and an "
"annotation).")
msg = "Annotated[t, ...]: t must be a type."
origin = typing._type_check(params[0], msg)
metadata = tuple(params[1:])
return _AnnotatedAlias(origin, metadata)
def __init_subclass__(cls, *args, **kwargs):
raise TypeError(
f"Cannot subclass {cls.__module__}.Annotated"
)
def _strip_annotations(t):
"""Strips the annotations from a given type.
"""
if isinstance(t, _AnnotatedAlias):
return _strip_annotations(t.__origin__)
if isinstance(t, typing._GenericAlias):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__:
return t
res = t.copy_with(stripped_args)
res._special = t._special
return res
return t
def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
"""Return type hints for an object.
This is often the same as obj.__annotations__, but it handles
forward references encoded as string literals, adds Optional[t] if a
default value equal to None is set and recursively replaces all
'Annotated[T, ...]' with 'T' (unless 'include_extras=True').
The argument may be a module, class, method, or function. The annotations
are returned as a dictionary. For classes, annotations include also
inherited members.
TypeError is raised if the argument is not of a type that can contain
annotations, and an empty dictionary is returned if no annotations are
present.
BEWARE -- the behavior of globalns and localns is counterintuitive
(unless you are familiar with how eval() and exec() work). The
search order is locals first, then globals.
- If no dict arguments are passed, an attempt is made to use the
globals from obj (or the respective module's globals for classes),
and these are also used as the locals. If the object does not appear
to have globals, an empty dictionary is used.
- If one dict argument is passed, it is used for both globals and
locals.
- If two dict arguments are passed, they specify globals and
locals, respectively.
"""
hint = typing.get_type_hints(obj, globalns=globalns, localns=localns)
if include_extras:
return hint
return {k: _strip_annotations(t) for k, t in hint.items()}
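# An illustrative sketch (hypothetical function `tag`): include_extras=True
# preserves the Annotated metadata that the default call strips away.
def _get_type_hints_example():
    def tag(x: Annotated[int, "tagged"]):
        return x
    assert get_type_hints(tag)["x"] is int
    hints = get_type_hints(tag, include_extras=True)
    assert hints["x"].__metadata__ == ("tagged",)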
# 3.6
else:
def _is_dunder(name):
"""Returns True if name is a __dunder_variable_name__."""
return len(name) > 4 and name.startswith('__') and name.endswith('__')
# Prior to Python 3.7 types did not have `copy_with`. A lot of the equality
# checks, argument expansion etc. are done on the _subs_tree. As a result we
# can't provide a get_type_hints function that strips out annotations.
class AnnotatedMeta(typing.GenericMeta):
"""Metaclass for Annotated"""
def __new__(cls, name, bases, namespace, **kwargs):
if any(b is not object for b in bases):
raise TypeError("Cannot subclass " + str(Annotated))
return super().__new__(cls, name, bases, namespace, **kwargs)
@property
def __metadata__(self):
return self._subs_tree()[2]
def _tree_repr(self, tree):
cls, origin, metadata = tree
if not isinstance(origin, tuple):
tp_repr = typing._type_repr(origin)
else:
tp_repr = origin[0]._tree_repr(origin)
metadata_reprs = ", ".join(repr(arg) for arg in metadata)
return f'{cls}[{tp_repr}, {metadata_reprs}]'
def _subs_tree(self, tvars=None, args=None): # noqa
if self is Annotated:
return Annotated
res = super()._subs_tree(tvars=tvars, args=args)
# Flatten nested Annotated
if isinstance(res[1], tuple) and res[1][0] is Annotated:
sub_tp = res[1][1]
sub_annot = res[1][2]
return (Annotated, sub_tp, sub_annot + res[2])
return res
def _get_cons(self):
"""Return the class used to create instance of this type."""
if self.__origin__ is None:
raise TypeError("Cannot get the underlying type of a "
"non-specialized Annotated type.")
tree = self._subs_tree()
while isinstance(tree, tuple) and tree[0] is Annotated:
tree = tree[1]
if isinstance(tree, tuple):
return tree[0]
else:
return tree
@typing._tp_cache
def __getitem__(self, params):
if not isinstance(params, tuple):
params = (params,)
if self.__origin__ is not None: # specializing an instantiated type
return super().__getitem__(params)
elif not isinstance(params, tuple) or len(params) < 2:
raise TypeError("Annotated[...] should be instantiated "
"with at least two arguments (a type and an "
"annotation).")
else:
msg = "Annotated[t, ...]: t must be a type."
tp = typing._type_check(params[0], msg)
metadata = tuple(params[1:])
return self.__class__(
self.__name__,
self.__bases__,
_no_slots_copy(self.__dict__),
tvars=_type_vars((tp,)),
# Metadata is a tuple so it won't be touched by _replace_args et al.
args=(tp, metadata),
origin=self,
)
def __call__(self, *args, **kwargs):
cons = self._get_cons()
result = cons(*args, **kwargs)
try:
result.__orig_class__ = self
except AttributeError:
pass
return result
def __getattr__(self, attr):
# For simplicity we just don't relay all dunder names
if self.__origin__ is not None and not _is_dunder(attr):
return getattr(self._get_cons(), attr)
raise AttributeError(attr)
def __setattr__(self, attr, value):
if _is_dunder(attr) or attr.startswith('_abc_'):
super().__setattr__(attr, value)
elif self.__origin__ is None:
raise AttributeError(attr)
else:
setattr(self._get_cons(), attr, value)
def __instancecheck__(self, obj):
raise TypeError("Annotated cannot be used with isinstance().")
def __subclasscheck__(self, cls):
raise TypeError("Annotated cannot be used with issubclass().")
class Annotated(metaclass=AnnotatedMeta):
"""Add context specific metadata to a type.
Example: Annotated[int, runtime_check.Unsigned] indicates to the
hypothetical runtime_check module that this type is an unsigned int.
Every other consumer of this type can ignore this metadata and treat
this type as int.
The first argument to Annotated must be a valid type, the remaining
arguments are kept as a tuple in the __metadata__ field.
Details:
- It's an error to call `Annotated` with less than two arguments.
- Nested Annotated are flattened::
Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3]
- Instantiating an annotated type is equivalent to instantiating the
underlying type::
Annotated[C, Ann1](5) == C(5)
- Annotated can be used as a generic type alias::
Optimized = Annotated[T, runtime.Optimize()]
Optimized[int] == Annotated[int, runtime.Optimize()]
OptimizedList = Annotated[List[T], runtime.Optimize()]
OptimizedList[int] == Annotated[List[int], runtime.Optimize()]
"""
# Python 3.8 has get_origin() and get_args() but those implementations aren't
# Annotated-aware, so we can't use those. Python 3.9's versions don't support
# ParamSpecArgs and ParamSpecKwargs, so only Python 3.10's versions will do.
if sys.version_info[:2] >= (3, 10):
get_origin = typing.get_origin
get_args = typing.get_args
# 3.7-3.9
elif PEP_560:
try:
# 3.9+
from typing import _BaseGenericAlias
except ImportError:
_BaseGenericAlias = typing._GenericAlias
try:
# 3.9+
from typing import GenericAlias
except ImportError:
GenericAlias = typing._GenericAlias
def get_origin(tp):
"""Get the unsubscripted version of a type.
This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar
and Annotated. Return None for unsupported types. Examples::
get_origin(Literal[42]) is Literal
get_origin(int) is None
get_origin(ClassVar[int]) is ClassVar
get_origin(Generic) is Generic
get_origin(Generic[T]) is Generic
get_origin(Union[T, int]) is Union
get_origin(List[Tuple[T, T]][int]) == list
get_origin(P.args) is P
"""
if isinstance(tp, _AnnotatedAlias):
return Annotated
if isinstance(tp, (typing._GenericAlias, GenericAlias, _BaseGenericAlias,
ParamSpecArgs, ParamSpecKwargs)):
return tp.__origin__
if tp is typing.Generic:
return typing.Generic
return None
def get_args(tp):
"""Get type arguments with all substitutions performed.
For unions, basic simplifications used by Union constructor are performed.
Examples::
get_args(Dict[str, int]) == (str, int)
get_args(int) == ()
get_args(Union[int, Union[T, int], str][int]) == (int, str)
get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
get_args(Callable[[], T][int]) == ([], int)
"""
if isinstance(tp, _AnnotatedAlias):
return (tp.__origin__,) + tp.__metadata__
if isinstance(tp, (typing._GenericAlias, GenericAlias)):
if getattr(tp, "_special", False):
return ()
res = tp.__args__
if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis:
res = (list(res[:-1]), res[-1])
return res
return ()
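# An illustrative sketch: unlike the stdlib helpers on 3.8/3.9, these
# backported versions understand Annotated (the names below are hypothetical).
def _get_origin_args_example():
    alias = Annotated[typing.List[int], "meta"]
    assert get_origin(alias) is Annotated
    assert get_args(alias) == (typing.List[int], "meta")
    assert get_origin(typing.List[int]) is list
    assert get_args(typing.Dict[str, int]) == (str, int)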
# 3.10+
if hasattr(typing, 'TypeAlias'):
TypeAlias = typing.TypeAlias
# 3.9
elif sys.version_info[:2] >= (3, 9):
class _TypeAliasForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
@_TypeAliasForm
def TypeAlias(self, parameters):
"""Special marker indicating that an assignment should
be recognized as a proper type alias definition by type
checkers.
For example::
Predicate: TypeAlias = Callable[..., bool]
It's invalid when used anywhere except as in the example above.
"""
raise TypeError(f"{self} is not subscriptable")
# 3.7-3.8
elif sys.version_info[:2] >= (3, 7):
class _TypeAliasForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
TypeAlias = _TypeAliasForm('TypeAlias',
doc="""Special marker indicating that an assignment should
be recognized as a proper type alias definition by type
checkers.
For example::
Predicate: TypeAlias = Callable[..., bool]
It's invalid when used anywhere except as in the example
above.""")
# 3.6
else:
class _TypeAliasMeta(typing.TypingMeta):
"""Metaclass for TypeAlias"""
def __repr__(self):
return 'typing_extensions.TypeAlias'
class _TypeAliasBase(typing._FinalTypingBase, metaclass=_TypeAliasMeta, _root=True):
"""Special marker indicating that an assignment should
be recognized as a proper type alias definition by type
checkers.
For example::
Predicate: TypeAlias = Callable[..., bool]
It's invalid when used anywhere except as in the example above.
"""
__slots__ = ()
def __instancecheck__(self, obj):
raise TypeError("TypeAlias cannot be used with isinstance().")
def __subclasscheck__(self, cls):
raise TypeError("TypeAlias cannot be used with issubclass().")
def __repr__(self):
return 'typing_extensions.TypeAlias'
TypeAlias = _TypeAliasBase(_root=True)
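# An illustrative sketch (hypothetical `Predicate`): TypeAlias only marks the
# assignment for static checkers; at runtime the alias is just its right-hand side.
def _type_alias_example():
    Predicate: TypeAlias = typing.Callable[..., bool]
    def apply(pred: Predicate, value: object) -> bool:
        return pred(value)
    assert apply(bool, 1) is True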
# Python 3.10+ has PEP 612
if hasattr(typing, 'ParamSpecArgs'):
ParamSpecArgs = typing.ParamSpecArgs
ParamSpecKwargs = typing.ParamSpecKwargs
# 3.6-3.9
else:
class _Immutable:
"""Mixin to indicate that object should not be copied."""
__slots__ = ()
def __copy__(self):
return self
def __deepcopy__(self, memo):
return self
class ParamSpecArgs(_Immutable):
"""The args for a ParamSpec object.
Given a ParamSpec object P, P.args is an instance of ParamSpecArgs.
ParamSpecArgs objects have a reference back to their ParamSpec:
P.args.__origin__ is P
This type is meant for runtime introspection and has no special meaning to
static type checkers.
"""
def __init__(self, origin):
self.__origin__ = origin
def __repr__(self):
return f"{self.__origin__.__name__}.args"
class ParamSpecKwargs(_Immutable):
"""The kwargs for a ParamSpec object.
Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs.
ParamSpecKwargs objects have a reference back to their ParamSpec:
P.kwargs.__origin__ is P
This type is meant for runtime introspection and has no special meaning to
static type checkers.
"""
def __init__(self, origin):
self.__origin__ = origin
def __repr__(self):
return f"{self.__origin__.__name__}.kwargs"
# 3.10+
if hasattr(typing, 'ParamSpec'):
ParamSpec = typing.ParamSpec
# 3.6-3.9
else:
# Inherits from list as a workaround for Callable checks in Python < 3.9.2.
class ParamSpec(list):
"""Parameter specification variable.
Usage::
P = ParamSpec('P')
Parameter specification variables exist primarily for the benefit of static
type checkers. They are used to forward the parameter types of one
callable to another callable, a pattern commonly found in higher order
functions and decorators. They are only valid when used in ``Concatenate``,
or as the first argument to ``Callable``. In Python 3.10 and higher,
they are also supported in user-defined Generics at runtime.
See class Generic for more information on generic types. An
example for annotating a decorator::
T = TypeVar('T')
P = ParamSpec('P')
def add_logging(f: Callable[P, T]) -> Callable[P, T]:
'''A type-safe decorator to add logging to a function.'''
def inner(*args: P.args, **kwargs: P.kwargs) -> T:
logging.info(f'{f.__name__} was called')
return f(*args, **kwargs)
return inner
@add_logging
def add_two(x: float, y: float) -> float:
'''Add two numbers together.'''
return x + y
Parameter specification variables defined with covariant=True or
contravariant=True can be used to declare covariant or contravariant
generic types. These keyword arguments are valid, but their actual semantics
are yet to be decided. See PEP 612 for details.
Parameter specification variables can be introspected. e.g.:
P.__name__ == 'P'
P.__bound__ == None
P.__covariant__ == False
P.__contravariant__ == False
Note that only parameter specification variables defined in global scope can
be pickled.
"""
# Trick Generic __parameters__.
__class__ = typing.TypeVar
@property
def args(self):
return ParamSpecArgs(self)
@property
def kwargs(self):
return ParamSpecKwargs(self)
def __init__(self, name, *, bound=None, covariant=False, contravariant=False):
super().__init__([self])
self.__name__ = name
self.__covariant__ = bool(covariant)
self.__contravariant__ = bool(contravariant)
if bound:
self.__bound__ = typing._type_check(bound, 'Bound must be a type.')
else:
self.__bound__ = None
# for pickling:
try:
def_mod = sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
def_mod = None
if def_mod != 'typing_extensions':
self.__module__ = def_mod
def __repr__(self):
if self.__covariant__:
prefix = '+'
elif self.__contravariant__:
prefix = '-'
else:
prefix = '~'
return prefix + self.__name__
def __hash__(self):
return object.__hash__(self)
def __eq__(self, other):
return self is other
def __reduce__(self):
return self.__name__
# Hack to get typing._type_check to pass.
def __call__(self, *args, **kwargs):
pass
if not PEP_560:
# Only needed in 3.6.
def _get_type_vars(self, tvars):
if self not in tvars:
tvars.append(self)
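# An illustrative sketch making the decorator pattern from the docstring above
# concrete; the logging call is elided and all names here are hypothetical.
def _paramspec_decorator_example():
    T = typing.TypeVar('T')
    P = ParamSpec('P')
    def add_logging(f: typing.Callable[P, T]) -> typing.Callable[P, T]:
        def inner(*args: P.args, **kwargs: P.kwargs) -> T:
            return f(*args, **kwargs)
        return inner
    @add_logging
    def add_two(x: float, y: float) -> float:
        return x + y
    assert add_two(1.0, 2.0) == 3.0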
# 3.6-3.9
if not hasattr(typing, 'Concatenate'):
# Inherits from list as a workaround for Callable checks in Python < 3.9.2.
class _ConcatenateGenericAlias(list):
# Trick Generic into looking into this for __parameters__.
if PEP_560:
__class__ = typing._GenericAlias
else:
__class__ = typing._TypingBase
# Flag in 3.8.
_special = False
# Attribute in 3.6 and earlier.
_gorg = typing.Generic
def __init__(self, origin, args):
super().__init__(args)
self.__origin__ = origin
self.__args__ = args
def __repr__(self):
_type_repr = typing._type_repr
return (f'{_type_repr(self.__origin__)}'
f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]')
def __hash__(self):
return hash((self.__origin__, self.__args__))
# Hack to get typing._type_check to pass in Generic.
def __call__(self, *args, **kwargs):
pass
@property
def __parameters__(self):
return tuple(
tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec))
)
if not PEP_560:
# Only required in 3.6.
def _get_type_vars(self, tvars):
if self.__origin__ and self.__parameters__:
typing._get_type_vars(self.__parameters__, tvars)
# 3.6-3.9
@typing._tp_cache
def _concatenate_getitem(self, parameters):
if parameters == ():
raise TypeError("Cannot take a Concatenate of no types.")
if not isinstance(parameters, tuple):
parameters = (parameters,)
if not isinstance(parameters[-1], ParamSpec):
raise TypeError("The last parameter to Concatenate should be a "
"ParamSpec variable.")
msg = "Concatenate[arg, ...]: each arg must be a type."
parameters = tuple(typing._type_check(p, msg) for p in parameters)
return _ConcatenateGenericAlias(self, parameters)
# 3.10+
if hasattr(typing, 'Concatenate'):
Concatenate = typing.Concatenate
_ConcatenateGenericAlias = typing._ConcatenateGenericAlias # noqa
# 3.9
elif sys.version_info[:2] >= (3, 9):
@_TypeAliasForm
def Concatenate(self, parameters):
"""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a
higher order function which adds, removes or transforms parameters of a
callable.
For example::
Callable[Concatenate[int, P], int]
See PEP 612 for detailed information.
"""
return _concatenate_getitem(self, parameters)
# 3.7-8
elif sys.version_info[:2] >= (3, 7):
class _ConcatenateForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
def __getitem__(self, parameters):
return _concatenate_getitem(self, parameters)
Concatenate = _ConcatenateForm(
'Concatenate',
doc="""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a
higher order function which adds, removes or transforms parameters of a
callable.
For example::
Callable[Concatenate[int, P], int]
See PEP 612 for detailed information.
""")
# 3.6
else:
class _ConcatenateAliasMeta(typing.TypingMeta):
"""Metaclass for Concatenate."""
def __repr__(self):
return 'typing_extensions.Concatenate'
class _ConcatenateAliasBase(typing._FinalTypingBase,
metaclass=_ConcatenateAliasMeta,
_root=True):
"""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a
higher order function which adds, removes or transforms parameters of a
callable.
For example::
Callable[Concatenate[int, P], int]
See PEP 612 for detailed information.
"""
__slots__ = ()
def __instancecheck__(self, obj):
raise TypeError("Concatenate cannot be used with isinstance().")
def __subclasscheck__(self, cls):
raise TypeError("Concatenate cannot be used with issubclass().")
def __repr__(self):
return 'typing_extensions.Concatenate'
def __getitem__(self, parameters):
return _concatenate_getitem(self, parameters)
Concatenate = _ConcatenateAliasBase(_root=True)
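# An illustrative sketch (hypothetical names): Concatenate prepends fixed
# argument types to a ParamSpec, e.g. for decorators that inject an argument.
def _concatenate_example():
    P = ParamSpec('P')
    alias = Concatenate[int, P]
    assert alias.__args__ == (int, P)
    # Usable wherever a Callable's argument list is expected:
    Handler = typing.Callable[Concatenate[int, P], bool]
    return Handler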
# 3.10+
if hasattr(typing, 'TypeGuard'):
TypeGuard = typing.TypeGuard
# 3.9
elif sys.version_info[:2] >= (3, 9):
class _TypeGuardForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
@_TypeGuardForm
def TypeGuard(self, parameters):
"""Special typing form used to annotate the return type of a user-defined
type guard function. ``TypeGuard`` only accepts a single type argument.
At runtime, functions marked this way should return a boolean.
``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static
type checkers to determine a more precise type of an expression within a
program's code flow. Usually type narrowing is done by analyzing
conditional code flow and applying the narrowing to a block of code. The
conditional expression here is sometimes referred to as a "type guard".
Sometimes it would be convenient to use a user-defined boolean function
as a type guard. Such a function should use ``TypeGuard[...]`` as its
return type to alert static type checkers to this intention.
Using ``-> TypeGuard`` tells the static type checker that for a given
function:
1. The return value is a boolean.
2. If the return value is ``True``, the type of its argument
is the type inside ``TypeGuard``.
For example::
def is_str(val: Union[str, float]):
# "isinstance" type guard
if isinstance(val, str):
# Type of ``val`` is narrowed to ``str``
...
else:
# Else, type of ``val`` is narrowed to ``float``.
...
Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower
form of ``TypeA`` (it can even be a wider form) and this may lead to
type-unsafe results. The main reason is to allow for things like
narrowing ``List[object]`` to ``List[str]`` even though the latter is not
a subtype of the former, since ``List`` is invariant. The responsibility of
writing type-safe type guards is left to the user.
``TypeGuard`` also works with type variables. For more information, see
PEP 647 (User-Defined Type Guards).
"""
item = typing._type_check(parameters, f'{self} accepts only single type.')
return typing._GenericAlias(self, (item,))
# 3.7-3.8
elif sys.version_info[:2] >= (3, 7):
class _TypeGuardForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
def __getitem__(self, parameters):
item = typing._type_check(parameters,
f'{self._name} accepts only a single type')
return typing._GenericAlias(self, (item,))
TypeGuard = _TypeGuardForm(
'TypeGuard',
doc="""Special typing form used to annotate the return type of a user-defined
type guard function. ``TypeGuard`` only accepts a single type argument.
At runtime, functions marked this way should return a boolean.
``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static
type checkers to determine a more precise type of an expression within a
program's code flow. Usually type narrowing is done by analyzing
conditional code flow and applying the narrowing to a block of code. The
conditional expression here is sometimes referred to as a "type guard".
Sometimes it would be convenient to use a user-defined boolean function
as a type guard. Such a function should use ``TypeGuard[...]`` as its
return type to alert static type checkers to this intention.
Using ``-> TypeGuard`` tells the static type checker that for a given
function:
1. The return value is a boolean.
2. If the return value is ``True``, the type of its argument
is the type inside ``TypeGuard``.
For example::
def is_str(val: Union[str, float]):
# "isinstance" type guard
if isinstance(val, str):
# Type of ``val`` is narrowed to ``str``
...
else:
# Else, type of ``val`` is narrowed to ``float``.
...
Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower
form of ``TypeA`` (it can even be a wider form) and this may lead to
type-unsafe results. The main reason is to allow for things like
narrowing ``List[object]`` to ``List[str]`` even though the latter is not
a subtype of the former, since ``List`` is invariant. The responsibility of
writing type-safe type guards is left to the user.
``TypeGuard`` also works with type variables. For more information, see
PEP 647 (User-Defined Type Guards).
""")
# 3.6
else:
class _TypeGuard(typing._FinalTypingBase, _root=True):
"""Special typing form used to annotate the return type of a user-defined
type guard function. ``TypeGuard`` only accepts a single type argument.
At runtime, functions marked this way should return a boolean.
``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static
type checkers to determine a more precise type of an expression within a
program's code flow. Usually type narrowing is done by analyzing
conditional code flow and applying the narrowing to a block of code. The
conditional expression here is sometimes referred to as a "type guard".
Sometimes it would be convenient to use a user-defined boolean function
as a type guard. Such a function should use ``TypeGuard[...]`` as its
return type to alert static type checkers to this intention.
Using ``-> TypeGuard`` tells the static type checker that for a given
function:
1. The return value is a boolean.
2. If the return value is ``True``, the type of its argument
is the type inside ``TypeGuard``.
For example::
def is_str(val: Union[str, float]):
# "isinstance" type guard
if isinstance(val, str):
# Type of ``val`` is narrowed to ``str``
...
else:
# Else, type of ``val`` is narrowed to ``float``.
...
Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower
form of ``TypeA`` (it can even be a wider form) and this may lead to
type-unsafe results. The main reason is to allow for things like
narrowing ``List[object]`` to ``List[str]`` even though the latter is not
a subtype of the former, since ``List`` is invariant. The responsibility of
writing type-safe type guards is left to the user.
``TypeGuard`` also works with type variables. For more information, see
PEP 647 (User-Defined Type Guards).
"""
__slots__ = ('__type__',)
def __init__(self, tp=None, **kwds):
self.__type__ = tp
def __getitem__(self, item):
cls = type(self)
if self.__type__ is None:
return cls(typing._type_check(item,
f'{cls.__name__[1:]} accepts only a single type.'),
_root=True)
raise TypeError(f'{cls.__name__[1:]} cannot be further subscripted')
def _eval_type(self, globalns, localns):
new_tp = typing._eval_type(self.__type__, globalns, localns)
if new_tp == self.__type__:
return self
return type(self)(new_tp, _root=True)
def __repr__(self):
r = super().__repr__()
if self.__type__ is not None:
r += f'[{typing._type_repr(self.__type__)}]'
return r
def __hash__(self):
return hash((type(self).__name__, self.__type__))
def __eq__(self, other):
if not isinstance(other, _TypeGuard):
return NotImplemented
if self.__type__ is not None:
return self.__type__ == other.__type__
return self is other
TypeGuard = _TypeGuard(_root=True)
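# An illustrative sketch (hypothetical `is_str_list`): at runtime a type guard
# is an ordinary bool-returning function; the narrowing is purely static.
def _type_guard_example():
    def is_str_list(val: typing.List[object]) -> TypeGuard[typing.List[str]]:
        return all(isinstance(x, str) for x in val)
    assert is_str_list(["a", "b"])
    assert not is_str_list(["a", 1])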
if hasattr(typing, "Self"):
Self = typing.Self
elif sys.version_info[:2] >= (3, 9):
class _SelfForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
@_SelfForm
def Self(self, params):
"""Used to spell the type of "self" in classes.
Example::
from typing_extensions import Self
class ReturnsSelf:
def parse(self, data: bytes) -> Self:
...
return self
"""
raise TypeError(f"{self} is not subscriptable")
elif sys.version_info[:2] >= (3, 7):
class _SelfForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
Self = _SelfForm(
"Self",
doc="""Used to spell the type of "self" in classes.
Example::
from typing_extensions import Self
class ReturnsSelf:
def parse(self, data: bytes) -> Self:
...
return self
"""
)
else:
class _Self(typing._FinalTypingBase, _root=True):
"""Used to spell the type of "self" in classes.
Example::
from typing_extensions import Self
class ReturnsSelf:
def parse(self, data: bytes) -> Self:
...
return self
"""
__slots__ = ()
def __instancecheck__(self, obj):
raise TypeError(f"{self} cannot be used with isinstance().")
def __subclasscheck__(self, cls):
raise TypeError(f"{self} cannot be used with issubclass().")
Self = _Self(_root=True)
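# An illustrative sketch (hypothetical `Shape`): with Self, a subclass calling
# set_scale is inferred to get the subclass back, not the base class.
def _self_example():
    class Shape:
        def set_scale(self, scale: float) -> Self:
            self.scale = scale
            return self
    return Shape().set_scale(0.5)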
if hasattr(typing, 'Required'):
Required = typing.Required
NotRequired = typing.NotRequired
elif sys.version_info[:2] >= (3, 9):
class _ExtensionsSpecialForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
@_ExtensionsSpecialForm
def Required(self, parameters):
"""A special typing construct to mark a key of a total=False TypedDict
as required. For example:
class Movie(TypedDict, total=False):
title: Required[str]
year: int
m = Movie(
title='The Matrix', # typechecker error if key is omitted
year=1999,
)
There is no runtime checking that a required key is actually provided
when instantiating a related TypedDict.
"""
item = typing._type_check(parameters, f'{self._name} accepts only single type')
return typing._GenericAlias(self, (item,))
@_ExtensionsSpecialForm
def NotRequired(self, parameters):
"""A special typing construct to mark a key of a TypedDict as
potentially missing. For example:
class Movie(TypedDict):
title: str
year: NotRequired[int]
m = Movie(
title='The Matrix', # typechecker error if key is omitted
year=1999,
)
"""
item = typing._type_check(parameters, f'{self._name} accepts only single type')
return typing._GenericAlias(self, (item,))
elif sys.version_info[:2] >= (3, 7):
class _RequiredForm(typing._SpecialForm, _root=True):
def __repr__(self):
return 'typing_extensions.' + self._name
def __getitem__(self, parameters):
item = typing._type_check(parameters,
'{} accepts only single type'.format(self._name))
return typing._GenericAlias(self, (item,))
Required = _RequiredForm(
'Required',
doc="""A special typing construct to mark a key of a total=False TypedDict
as required. For example:
class Movie(TypedDict, total=False):
title: Required[str]
year: int
m = Movie(
title='The Matrix', # typechecker error if key is omitted
year=1999,
)
There is no runtime checking that a required key is actually provided
when instantiating a related TypedDict.
""")
NotRequired = _RequiredForm(
'NotRequired',
doc="""A special typing construct to mark a key of a TypedDict as
potentially missing. For example:
class Movie(TypedDict):
title: str
year: NotRequired[int]
m = Movie(
title='The Matrix', # typechecker error if key is omitted
year=1999,
)
""")
else:
# NOTE: Modeled after _Final's implementation when _FinalTypingBase available
class _MaybeRequired(typing._FinalTypingBase, _root=True):
__slots__ = ('__type__',)
def __init__(self, tp=None, **kwds):
self.__type__ = tp
def __getitem__(self, item):
cls = type(self)
if self.__type__ is None:
return cls(typing._type_check(item,
'{} accepts only single type.'.format(cls.__name__[1:])),
_root=True)
raise TypeError('{} cannot be further subscripted'
.format(cls.__name__[1:]))
def _eval_type(self, globalns, localns):
new_tp = typing._eval_type(self.__type__, globalns, localns)
if new_tp == self.__type__:
return self
return type(self)(new_tp, _root=True)
def __repr__(self):
r = super().__repr__()
if self.__type__ is not None:
r += '[{}]'.format(typing._type_repr(self.__type__))
return r
def __hash__(self):
return hash((type(self).__name__, self.__type__))
def __eq__(self, other):
if not isinstance(other, _MaybeRequired):
return NotImplemented
if self.__type__ is not None:
return self.__type__ == other.__type__
return self is other
class _Required(_MaybeRequired, _root=True):
"""A special typing construct to mark a key of a total=False TypedDict
as required. For example:
class Movie(TypedDict, total=False):
title: Required[str]
year: int
m = Movie(
title='The Matrix', # typechecker error if key is omitted
year=1999,
)
There is no runtime checking that a required key is actually provided
when instantiating a related TypedDict.
"""
class _NotRequired(_MaybeRequired, _root=True):
"""A special typing construct to mark a key of a TypedDict as
potentially missing. For example:
class Movie(TypedDict):
title: str
year: NotRequired[int]
m = Movie(
title='The Matrix', # typechecker error if key is omitted
year=1999,
)
"""
Required = _Required(_root=True)
NotRequired = _NotRequired(_root=True)
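# An illustrative sketch (hypothetical `Movie`, assuming the TypedDict defined
# earlier in this module): Required/NotRequired override totality per key, and
# nothing is enforced at runtime.
def _required_example():
    class Movie(TypedDict, total=False):
        title: Required[str]
        year: int
    m: Movie = {"title": "The Matrix"}  # a checker allows omitting `year`
    return m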