漂亮格式的numpy JSON编码器

时间:2019-08-10 18:08:56

标签: python json numpy pprint

我一直在寻找一种使用numpy保存json数据的方法,同时保留numpy易于阅读的漂亮打印格式

this answer的启发,我选择使用pprint而不是base64以所需的格式写数据,因此给出:

import numpy as np
data = np.random.random((1,3,2))

磁盘上生成的文件应类似于:

{
    "__dtype__": "float64", 
    "__ndarray__": [[[0.7672818918130646 , 0.6846412220229668 ],
                     [0.7082023466738064 , 0.0896531267221291 ],
                     [0.43898454934160147, 0.9245898883694668 ]]]
}

出现了一些打

  • 虽然json可以读入格式为[[...]]的列表列表,但是numpy的浮点格式存在问题。例如,[[0., 0., 0.]]读回时会产生错误,而[[0.0, 0.0, 0.0]]会很好。

  • pformat将输出array([[0., 0., 0.]]),其中必须array()进行解析,否则json在读回数据时会引发错误。

    < / li>

要解决这些问题,我必须进行一些字符串解析,这导致我的当前代码如下:

import json, sys
import numpy as np
import pprint as pp

# Set numpy's printoptions to display all the data with max precision
np.set_printoptions(threshold=np.inf,
                    linewidth=sys.maxsize,
                    suppress=True,
                    nanstr='0.0',
                    infstr='0.0', 
                    precision=np.finfo(np.longdouble).precision)     



# Modified version of Adam Hughes's https://stackoverflow.com/a/27948073/1429402
def save_formatted(fname,data):

    class NumpyEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, np.ndarray):
                return {'__ndarray__': self.numpy_to_string(obj),
                        '__dtype__': str(obj.dtype)}            

            return json.JSONEncoder.default(self, obj)


        def numpy_to_string(self,data):
            ''' Use pprint to generate a nicely formatted string
            '''

            # Get rid of array(...) and keep only [[...]]
            f = pp.pformat(data, width=sys.maxsize)
            f = f[6:-1].splitlines() # get rid of array(...) and keep only [[...]]

            # Remove identation caused by printing "array(" 
            for i in xrange(1,len(f)):
                f[i] = f[i][6:]

            return '\n'.join(f)


    # Parse json stream and fix formatting.
    # JSON doesn't support float arrays written as [0., 0., 0.]
    # so we look for the problematic numpy print syntax and correct
    # it to be readable natively by JSON, in this case: [0.0, 0.0, 0.0]
    with open(fname,'w') as io:
        for line in json.dumps(data, sort_keys=False, indent=4, cls=NumpyEncoder).splitlines():
            if '"__ndarray__": "' in line:
                index = line.index('"__ndarray__": "')
                lines = line.split('"__ndarray__": "')[-1][:-1]
                lines = lines.replace('. ','.0')  # convert occurences of ". " to ".0"    ex: 3. , 2. ]
                lines = lines.replace('.,','.0,') # convert occurences of ".," to ".0,"   ex: 3., 2.,
                lines = lines.replace('.]','.0]') # convert occurences of ".]" to ".0],"  ex: 3., 2.]
                lines = lines.split('\\n')

                # write each lines with appropriate indentation
                for i in xrange(len(lines)):
                    if i == 0:
                        indent = ' '*index
                        io.write(('%s"__ndarray__": %s\n"'%(indent,lines[i]))[:-1]) 
                    else:
                        indent = ' '*(index+len('"__ndarray__": "')-1)
                        io.write('%s%s\n'%(indent,lines[i]))                        

            else:
                io.write('%s\n'%line)



def load_formatted(fname):

    def json_numpy_obj_hook(dct):
        if isinstance(dct, dict) and '__ndarray__' in dct:
            return np.array(dct['__ndarray__']).astype(dct['__dtype__'])        
        return dct

    with open(fname,'r') as io:
        return json.load(io, object_hook=json_numpy_obj_hook)

要测试:

data = np.random.random((200,3,1000))
save_formatted('test.data', data)
data_ = load_formatted('test.data')

print np.allclose(data,data_) # Returns True

问题

我的解决方案很适合我,但是它的字符串解析方面使它在处理大型数据数组时很慢。是否会有更好的方法来达到预期的效果? regular expression可以代替我的序列str.replace()调用吗?或者,也许pprint首先可以用来正确格式化我的字符串?有没有更好的方法来制作jsonnumpy的打印格式这样的列表?

1 个答案:

答案 0 :(得分:0)

我无法给出具体的指示,但是我相信您最好的选择是找到一些开源的漂亮打印库,并用numpy的使用规则进行调整(numpy也是开源的,因此“逆向工程”)。

感谢How to prettyprint a JSON file?的一个示例:https://github.com/andy-gh/pygrid/blob/master/prettyjson.py(不一定是很好的例子,但是很好地说明了prettyprinter的大小并不大)。

我的信心在于,仅吐出所有这些元素和它们之间的间隙比在另一台漂亮打印机的结果上使用replace(我在您的代码中看到的)要快得多。

如果例程可以用cython重写,那就更好了。

如果您对解析感兴趣,它使用的ijson和库可以提供流式json的迭代解析,如果您的json不适合RAM,则可能会有帮助。