首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将numpy数组写入lmdb

将numpy数组写入lmdb
EN

Stack Overflow用户
提问于 2017-05-30 15:32:50
回答 1查看 3.6K关注 0票数 2

我试图在python中编写一些numpy数组到lmdb:

代码语言:javascript
复制
import numpy as np
import lmdb

def write_lmdb(filename):
    lmdb_env = lmdb.open(filename, map_size=int(1e9))
    lmdb_txn = lmdb_env.begin(write=True)

    X= np.array([[1.0, 0.0], [0.1, 2.0]])
    y= np.array([1.4, 2.1])

    #Put first pair of arrays
    lmdb_txn.put('X', X)
    lmdb_txn.put('y', y)

    #Put second pair of arrays
    lmdb_txn.put('X', X+1.6)
    lmdb_txn.put('y', y+1.2)

def read_lmdb(filename):
    lmdb_env = lmdb.open(filename)
    lmdb_txn = lmdb_env.begin()
    lmdb_cursor = lmdb_txn.cursor()
    for key, value in lmdb_cursor:
        print type(key)
        print type(value)

        print key
        print value

write_lmdb('temp.db')
read_lmdb('temp.db')

但是read_lmdb什么也不打印,向lmdb写入numpy数组的正确方法是什么?

更新:基于@frankyjuang答案的成功地做到了这一点,但是不是很优雅:多维数组失去了它的形状,每个数组都应该有自己的名称。

代码语言:javascript
复制
import numpy as np
import lmdb

def write_lmdb(filename):
    print 'Write lmdb'

    lmdb_env = lmdb.open(filename, map_size=int(1e9))

    n_samples= 2
    X= (255*np.random.rand(n_samples,3,4,3)).astype(np.uint8)
    y= np.random.rand(n_samples).astype(np.float32)

    for i in range(n_samples):
        with lmdb_env.begin(write=True) as lmdb_txn:
            lmdb_txn.put('X_'+str(i), X)
            lmdb_txn.put('y_'+str(i), y)

            print 'X:',X
            print 'y:',y

def read_lmdb(filename):
    print 'Read lmdb'

    lmdb_env = lmdb.open(filename)
    lmdb_txn = lmdb_env.begin()
    lmdb_cursor = lmdb_txn.cursor()

    n_samples=0
    with lmdb_env.begin() as lmdb_txn:
        with lmdb_txn.cursor() as lmdb_cursor:
            for key, value in lmdb_cursor:  
                print key
                if('X' in key):
                    print np.fromstring(value, dtype=np.uint8)
                if('y' in key):
                    print np.fromstring(value, dtype=np.float32)

                n_samples=n_samples+1

    print 'n_samples',n_samples

write_lmdb('temp.db')
read_lmdb('temp.db')

测试脚本输出应该如下所示:

代码语言:javascript
复制
Write lmdb
X: [[[[ 48 224 119]
   [ 76  87 174]
   [ 14  88 183]
   [ 76 234  56]]

  [[234 223  65]
   [ 63  85 175]
   [184 252 125]
   [100   7 225]]

  [[134 159  41]
   [  2 146 221]
   [ 99  74 225]
   [169  57  59]]]


 [[[100 202   3]
   [ 88 204 131]
   [ 96 238 243]
   [103  58  30]]

  [[157 125 107]
   [238 207  99]
   [102 220  64]
   [ 27 240  33]]

  [[ 74  93 131]
   [107  88 206]
   [ 55  86  35]
   [212 235 187]]]]
y: [ 0.80826157  0.01407595]
X: [[[[ 48 224 119]
   [ 76  87 174]
   [ 14  88 183]
   [ 76 234  56]]

  [[234 223  65]
   [ 63  85 175]
   [184 252 125]
   [100   7 225]]

  [[134 159  41]
   [  2 146 221]
   [ 99  74 225]
   [169  57  59]]]


 [[[100 202   3]
   [ 88 204 131]
   [ 96 238 243]
   [103  58  30]]

  [[157 125 107]
   [238 207  99]
   [102 220  64]
   [ 27 240  33]]

  [[ 74  93 131]
   [107  88 206]
   [ 55  86  35]
   [212 235 187]]]]
y: [ 0.80826157  0.01407595]
Read lmdb
X_0
[ 48 224 119  76  87 174  14  88 183  76 234  56 234 223  65  63  85 175
 184 252 125 100   7 225 134 159  41   2 146 221  99  74 225 169  57  59
 100 202   3  88 204 131  96 238 243 103  58  30 157 125 107 238 207  99
 102 220  64  27 240  33  74  93 131 107  88 206  55  86  35 212 235 187]
X_1
[ 48 224 119  76  87 174  14  88 183  76 234  56 234 223  65  63  85 175
 184 252 125 100   7 225 134 159  41   2 146 221  99  74 225 169  57  59
 100 202   3  88 204 131  96 238 243 103  58  30 157 125 107 238 207  99
 102 220  64  27 240  33  74  93 131 107  88 206  55  86  35 212 235 187]
y_0
[ 0.80826157  0.01407595]
y_1
[ 0.80826157  0.01407595]
n_samples 4
EN

回答 1

Stack Overflow用户

发布于 2017-05-30 18:08:28

将您的事务包装在with下。并记住使用np.fromstring将值从字节(字符串)转换回numpy数组。

老实说,在lmdb中存储numpy数组不是一个好主意,因为从数组到字节的转换回数组会丢失一些信息(例如。形状)。您可以尝试使用泡菜存储一组numpy数组。

代码语言:javascript
复制
def write_lmdb(filename):
    ...
    with lmdb_env.begin(write=True) as lmdb_txn:
        ...

def read_lmdb(filename):
    ...
    with lmdb_env.begin() as lmdb_txn:
        with lmdb_txn.cursor() as lmdb_cursor:
            ...
            print np.fromstring(value, dtype=np.float64)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/44266384

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档