python – numpy的矢量化基数排序 – 它可以击败np.sort吗?

Numpy没有yet有一个基数排序,所以我想知道是否有可能使用预先存在的numpy函数编写一个.到目前为止,我有以下,它确实有效,但比numpy的快速排序慢约10倍.

line profiler output

测试和基准测试:

a = np.random.randint(0,1e8,1e6)
assert(np.all(radix_sort(a) == np.sort(a))) 
%timeit np.sort(a)
%timeit radix_sort(a)

mask_b循环可以至少部分地被矢量化,在来自&的掩模之间广播,并且使用具有轴arg的cumsum,但是这最终是悲观化,可能是由于增加的存储器占用.

如果有人能够看到一种方法来改进我所拥有的东西,我会有兴趣听到,即使它仍然比np.sort慢……这更像是一种求知欲和对numpy技巧的兴趣.

请注意,can implement快速计数排序很容易,但这只与小整数数据有关.

编辑1:从循环中取出np.arange(n)会有所帮助,但这并不是非常令人兴奋.

编辑2:cumsum实际上是多余的(ooops!)但是这个更简单的版本只对性能有所帮助.

def radix_sort(a):
    bit_len = np.max(a).bit_length()
    n = len(a)
    cached_arange = arange(n)
    idx = np.empty(n,dtype=int) # fully overwritten each iteration
    for mask_b in xrange(bit_len):
        is_one = (a & 2**mask_b).astype(bool)
        n_ones = np.sum(is_one)      
        n_zeros = n-n_ones
        idx[~is_one] = cached_arange[:n_zeros]
        idx[is_one] = cached_arange[:n_ones] + n_zeros
        # next three lines just do: a[idx] = a,but correctly
        new_a = np.empty(n,dtype=a.dtype)
        new_a[idx] = a
        a = new_a
    return a

编辑3:如果您在多个步骤中构造idx,则可以一次循环两个或更多个,而不是循环使用单个位.使用2位有点帮助,我没有尝试过更多:

idx[is_zero] = np.arange(n_zeros)
idx[is_one] = np.arange(n_ones)
idx[is_two] = np.arange(n_twos)
idx[is_three] = np.arange(n_threes)

编辑4和5:对于我正在测试的输入,4位似乎是最好的.此外,你可以完全摆脱idx步骤.现在只比np.sort(source available as gist)慢5倍,而不是10倍:

enter image description here

编辑6:这是上面的一个整理版本,但它也有点慢. 80%的时间用于重复和提取 – 如果只有一种方法来广播提取物:( …

def radix_sort(a,batch_m_bits=3):
    bit_len = np.max(a).bit_length()
    batch_m = 2**batch_m_bits
    mask = 2**batch_m_bits - 1
    val_set = np.arange(batch_m,dtype=a.dtype)[:,nax] # nax = np.newaxis
    for _ in range((bit_len-1)//batch_m_bits + 1): # ceil-division
        a = np.extract((a & mask)[nax,:] == val_set,np.repeat(a[nax,:],batch_m,axis=0))
        val_set <<= batch_m_bits
        mask <<= batch_m_bits
    return a

编辑7& 8:实际上,您可以使用asprtrided从numpy.lib.stride_tricks广播提取,但它似乎没有太大的性能帮助:

enter image description here

最初这对我有意义,理由是提取将遍历整个数组batch_m次,因此CPU请求的高速缓存行总数将与之前相同(只是在它具有的过程结束时)请求每个缓存行batch_m次).然而,the reality是提取不够聪明,无法迭代任意阶梯数组,并且必须在开始之前扩展数组,即无论如何最终都要完成重复.
事实上,在查看了提取源之后,我现在看到我们用这种方法做的最好的事情是:

a = a[np.flatnonzero((a & mask)[nax,:] == val_set) % len(a)]

这比提取略慢.但是,如果len(a)是2的幂,我们可以用&更换昂贵的mod操作. (len(a) – 1),它最终比提取版本快一点(现在大约4.9x np.sort为a = randint(0,2 ** 20).我想我们可以做到这一点通过零填充来处理两个长度的非幂,然后在排序结束时裁剪额外的零…然而,除非长度已经接近2的幂,否则这将是一个悲观.

最佳答案
我和Numba一起去看看基数排序的速度有多快. Numba(通常)表现良好的关键是写出所有循环,这非常有启发性.我最终得到了以下内容:

from numba import jit

@jit
def radix_loop(nbatches,batch_m_bits,bitsums,a,out):
    mask = (1 << batch_m_bits) - 1
    for shift in range(0,nbatches*batch_m_bits,batch_m_bits):
        # set bit sums to zero
        for i in range(bitsums.shape[0]):
            bitsums[i] = 0

        # determine bit sums
        for i in range(a.shape[0]):
            j = (a[i] & mask) >> shift
            bitsums[j] += 1

        # take the cumsum of the bit sums
        cumsum = 0
        for i in range(bitsums.shape[0]):
            temp = bitsums[i]
            bitsums[i] = cumsum
            cumsum += temp

        # sorting loop
        for i in range(a.shape[0]):
            j = (a[i] & mask) >> shift
            out[bitsums[j]] = a[i]
            bitsums[j] += 1

        # prepare next iteration
        mask <<= batch_m_bits
        # cant use `temp` here because of numba internal types
        temp2 = a
        a = out
        out = temp2

    return a

从4个内圈开始,很容易看出它是第4个,因此很难用Numpy进行矢量化.

欺骗这个问题的一种方法是从Scipy:scipy.sparse.coo.coo_tocsr中引入一个特定的C函数.它与上面的Python函数几乎完全相同的内部循环,因此可以滥用在Python中编写更快的“矢量化”基数排序.也许是这样的:

from scipy.sparse.coo import coo_tocsr

def radix_step(radix,keys,w):
    coo_tocsr(radix,1,a.size,w,w)
    return w,a

def scipysparse_radix_perbyte(a):
    # coo_tocsr internally works with system int and upcasts
    # anything else. We need to copy anyway to not mess with
    # original array. Also take into account endianness...
    a = a.astype('

编辑:稍微优化一下功能..查看编辑历史记录.

如上所述的LSB基数排序的一个低效率是阵列在RAM中完全洗牌多次,这意味着CPU缓存使用得不是很好.为了尝试减轻这种影响,可以选择先使用MSB基数排序进行传递,将项目放在大致正确的RAM块中,然后使用LSB基数排序对每个结果组进行排序.这是一个实现:

def scipysparse_radix_hybrid(a,bbits=8,gbits=8):
    """
    Parameters
    ----------
    a : Array of non-negative integers to be sorted.
    bbits : Number of bits in radix for LSB sorting.
    gbits : Number of bits in radix for MSB grouping.
    """
    a = a.copy()
    bitlen = int(a.max()).bit_length()
    work = np.empty_like(a)

    # Group values by single iteration of MSB radix sort:
    # Casting to np.int_ to get rid of python BigInt
    ngroups = np.int_(2**gbits)
    group_offset = np.empty(ngroups + 1,int)
    shift = max(bitlen-gbits,0)
    a,work = radix_step(ngroups,a>>shift,group_offset,work)
    bitlen = shift
    if not bitlen:
        return a

    # LSB radix sort each group:
    agroups = np.split(a,group_offset[1:-1])
    # Mask off high bits to not undo the grouping..
    gmask = (1 << shift) - 1
    nbatch = (bitlen-1) // bbits + 1
    radix = np.int_(2**bbits)
    _ = np.empty(radix + 1,int)
    for agi in agroups:
        if not agi.size:
            continue
        mask = (radix - 1) & gmask
        wgi = work[:agi.size]
        for shift in range(0,nbatch*bbits,bbits):
            keys = (agi & mask) >> shift
            agi,wgi = radix_step(radix,agi,wgi)
            mask = (mask << bbits) & gmask
        if nbatch % 2:
            # Copy result back in to `a`
            wgi[...] = agi
    return a

计时(在我的系统上为每个设置提供最佳性能):

def numba_radix(a,batch_m_bits=8):
    a = a.copy()
    bit_len = int(a.max()).bit_length()
    nbatches = (bit_len-1)//batch_m_bits +1
    work = np.zeros_like(a)
    bitsums = np.zeros(2**batch_m_bits + 1,int)
    srtd = radix_loop(nbatches,work)
    return srtd

a = np.random.randint(0,1e6)
%timeit numba_radix(a,9)
# 10 loops,best of 3: 76.1 ms per loop
%timeit np.sort(a)
#10 loops,best of 3: 115 ms per loop
%timeit scipysparse_radix_perbyte(a)
#10 loops,best of 3: 95.2 ms per loop
%timeit scipysparse_radix_hybrid(a,11,6)
#10 loops,best of 3: 75.4 ms per loop

正如预期的那样,Numba的表现非常出色.而且通过一些巧妙应用现有的C-extension,可以击败numpy.sort. IMO在优化级别你已经得到了它的价值 – 它也考虑了Numpy的附加组件,但我不会真正考虑我的答案中的实现“矢量化”:大部分工作是在外部完成的专用功能.

令我印象深刻的另一件事是对基数选择的敏感性.对于我尝试的大多数设置,我的实现仍然比numpy.sort慢,所以在实践中需要某种启发式方法来提供全面的性能.

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


本文从多个角度分析了vi编辑器保存退出命令。我们介绍了保存和退出vi编辑器的命令,以及如何撤销更改、移动光标、查找和替换文本等实用命令。希望这些技巧能帮助你更好地使用vi编辑器。
Python中的回车和换行是计算机中文本处理中的两个重要概念,它们在代码编写中扮演着非常重要的角色。本文从多个角度分析了Python中的回车和换行,包括回车和换行的概念、使用方法、使用场景和注意事项。通过本文的介绍,读者可以更好地理解和掌握Python中的回车和换行,从而编写出更加高效和规范的Python代码。
SQL Server启动不了错误1067是一种比较常见的故障,主要原因是数据库服务启动失败、权限不足和数据库文件损坏等。要解决这个问题,我们需要检查服务日志、重启服务器、检查文件权限和恢复数据库文件等。在日常的数据库运维工作中,我们应该时刻关注数据库的运行状况,及时发现并解决问题,以确保数据库的正常运行。
信息模块是一种可重复使用的、可编程的、可扩展的、可维护的、可测试的、可重构的软件组件。信息模块的端接需要从接口设计、数据格式、消息传递、函数调用等方面进行考虑。信息模块的端接需要满足高内聚、低耦合的原则,以保证系统的可扩展性和可维护性。
本文从电脑配置、PyCharm版本、Java版本、配置文件以及程序冲突等多个角度分析了Win10启动不了PyCharm的可能原因,并提供了解决方法。
本文主要从多个角度分析了安装SQL Server 2012时可能出现的错误,并提供了解决方法。
Pycharm是一款非常优秀的Python集成开发环境,它可以让Python开发者更加高效地进行代码编写、调试和测试。在Pycharm中设置解释器非常简单,我们可以通过创建新项目、修改项目解释器、设置全局解释器等多种方式进行设置。
Python中有多种方法可以将字符串转换为整数,包括使用int()函数、try-except语句、正则表达式、map()函数、ord()函数和reduce()函数。在实际应用中,应根据具体情况选择最合适的方法。
本文介绍了导入CSV文件的多种方法,包括使用Excel、Python和R等工具。同时,还介绍了导入CSV文件时需要注意的一些细节和问题。CSV文件是数据处理和分析中不可或缺的一部分,希望本文能够对读者有所帮助。
mongodb是一种新型的数据库,它采用了面向文档的数据模型,具有灵活性、高性能和高可用性等优势。但是,mongodb也存在数据结构混乱、安全性和学习成本高等问题。
当Python运行不了时,我们应该从代码、Python环境、操作系统和硬件设备等多个角度来排查问题,并采取相应的解决措施。
Python列表是一种常见的数据类型,排序是列表操作中的一个重要部分。本文介绍了Python列表降序排序的方法,包括使用sort()函数、sorted()函数以及自定义函数进行排序。使用sort()函数可以简单方便地实现降序排序,但会改变原始列表的顺序;使用sorted()函数可以保留原始列表的顺序,但需要创建一个新的列表;使用自定义函数可以灵活地控制排序的方式,但需要编写额外的代码。
本文介绍了如何使用Python输入一段英文并统计其中的单词个数,从去除标点符号、忽略单词大小写、排除常用词汇等多个角度进行了分析。此外,还介绍了使用NLTK库进行单词统计的方法。
虚拟环境可以帮助我们在同一台机器上运行不同版本的Python、安装不同的Python包,并且不会相互影响。创建虚拟环境的命令是python3 -m venv myenv,进入虚拟环境的命令是source myenv/bin/activate,退出虚拟环境的命令是deactivate。在虚拟环境中可以使用pip安装包,也可以使用Python运行程序。
本文从XHR对象、fetch API和jQuery三个方面分析了JS获取响应状态的方法及其应用。以上三种方法都可以轻松地发送HTTP请求,并处理响应数据。
桌面的命令包括常见的操作命令、系统命令、批处理命令以及第三方应用程序提供的命令。我们可以通过鼠标右键点击桌面、创建快捷方式、创建批处理文件等方式来运用这些命令,从而更好地管理计算机,提高工作效率。
本文分析了应用程序闪退的多个原因,包括应用程序本身存在问题、手机或平板电脑系统问题、硬件问题、网络问题和其他原因。同时,本文提供了解决闪退问题的多种方式,包括更新或卸载重新下载应用程序、升级系统或进行修复、清理手机缓存、清理不必要的文件或者是更换电池等方式来解决、确保网络信号的稳定性、注意用户隐私和安全问题。
本文介绍了使用Python下载图片的多种方法,包括使用Python标准库urllib.request、第三方库requests、多线程和异步IO。这些方法在不同情况下都有它们的优缺点。使用这些方法,我们可以轻松地将网络上的图片下载到本地,方便我们在离线状态下查看或处理这些图片。
MySQL数据文件是指存储MySQL数据库中数据的文件,存储位置的选择对数据库的性能、可靠性和安全性都有着重要的影响。本文从存储位置的选择、存储设备的选择、存储空间的管理和存储位置的安全性等多个角度对MySQL数据文件的存储位置进行分析,最后得出需要根据实际情况综合考虑多个因素,选择合适的存储位置和存储设备,并进行有效的存储空间管理和安全措施的结论。
AS400是一种主机操作系统,每个库都包含多个表。查询库表总数是一项基本任务。可以使用命令行、系统管理界面以及数据库管理工具来查询库表总数。查询库表总数可以帮助用户更好地管理和优化数据,包括规划数据存储、优化查询性能以及管理空间资源。