Pythonで行列演算

[pukiwiki]
Pythonの画像ライブラリ[[PIL>ググる:Python PIL]]をつかってて困るのが、太い線を描けないことと、スプライン曲線の描画関数が無いこと。
Pythonで拡張できるCGソフト(BlenderやInkscapeなど)を使うって方法もありますけれども、今回は自前でチャレンジしてみることに。

数回に分けて必要なプログラムをメモしますー

まずは行列演算
[/pukiwiki]

[pukiwiki]
以前、Python Recipeで見たような記憶があるのですが、よくわからなかったので自前で。

普段は[[NumPy>ググる:NumPy Python]]を使ってるのですが、配布のときに「NumPy必須」ってのも困ることもあるかも、ということで、スプライン補間の計算に必要な行列の積(ドット?内積?)だけ実装することに。今回は計算量もたかが知れてますし。

◎ m x n の、2次元行列式の積のみ。
3次元以上の配列には対応してません。

printデバッグがいっぱい入っててかっこ悪いです。すんません

テストとして、ランダムな行列を作成、NumPyでの計算結果と等しいかチェックしています。

[/pukiwiki]
*mat.py
ファイル名 mat.pyとして保存してください

""" mat.py"""
def flatten(m):
    f=[]
    for x in m:
        f.extend(x)
    return f
class Mat(object):
    debug=False
    def __init__(self,lst=None,shape=None):
        debug=self.debug
        if lst and shape :
            if debug :print "*1"
            r,c=shape
            if len(lst)!= r*c: raise "shape not match"
            self.lst=lst[:]
            self.shape=shape[:]           
        elif not lst and shape :
            if debug :print "*2"
            r,c=shape
            self.lst=range(r*c)
            
            if debug :print "%s (%s,%s)"%(self.lst,r,c)
            self.shape=shape[:]
        elif lst and not shape :
            if debug :print "*3"

            self.shape=(len(lst),1)
            self.lst=lst[:]
        else :
            if debug :print "*4"            
            self.lst=[]
            self.shape=[0,0]
        self.rows,self.cols=self.shape
    def __eq__(self,b):
        try :
            return  self.shape==b.shape and self.lst==b.lst 

        except :
            return False
    def reshape( self,*shape):
        #print "@reshape"
        if not shape : raise "reshape error"
        c=self.copy()
        rows,cols=self.shape
        l=rows*cols
        if l!=len(self.lst) : raise "shape not match"
        c.shape=shape[:]
        return c

    def append(self,x):
        if not len(x) != self.cols :
            raise "cols not match"
        self.lst.append(x)
        self.rows+=1
        self.shape=(self.rows,self.cols)
    def copy(self):
        return Mat(self.lst[:],self.shape[:])
    def __iter__(self):
        return iter(self.lst)
    def __getitem__(self,idx):
        if type(idx)==int :
            #print "*"
            r=idx
            return Mat([self.lst[r*self.cols+c] for c in xrange(self.cols)])
        else :        
            r,c=idx
            return self.lst[r*self.cols+c]
    def _calc(self,b,fnc):
        a=self
        t=type(b)
        if t==int or t==float :
            c=Mat(shape=self.shape)
            
            for i in xrange(a.rows):
                for j in xrange(a.cols):
                    c[i,j]=fnc(a[i,j],b)
            return c            
        elif not isinstance(b,Mat):
            raise "add : type missmatch"
        if a.shape!=b.shape :
            raise "add : shape not match"
        c=Mat(shape=a.shape)
        for i in xrange(a.rows):
            for j in xrange(a.cols):
                c[i,j]=fnc(a[i,j],b[i,j])
        return c
    def __add__(self,b):
        return self._calc(b,lambda x,y : x+y)
    def __sub__(self,b):
        return self._calc(b,lambda x,y : x-y)
    def __mul__(self,b):
        return self._calc(b,lambda x,y : x*y)
    def __div__(self,b):
        return self._calc(b,lambda x,y : x/y)
        
        
    def dot(self,b):
        return dot(self,b)
        
    def __setitem__(self,idx,val):
        r,c=idx
        self.lst[r*self.cols+c]=val
    def __repr__(self):
        return repr(self.lst)

    def __len__(self):
        return len(self.lst)
    def _set_shape(self,sh):
        self.rows,self.cols=sh
    def _get_shape(self):
        return self.rows,self.cols
Mat.shape=property(Mat._get_shape,Mat._set_shape)
debug=False
def dot(a,b):
    
    rows_b,cols_b=b.shape
    rows_a,cols_a=a.shape
    if cols_a!=rows_b : raise "shape not match"
    cols_c=cols_b
    rows_c=rows_a
    c=Mat(shape=(rows_c,cols_c))
    for i in xrange(cols_b):
        
        for j in xrange(rows_a):
            v=0
            for u in xrange(cols_a):
                v_=a[j,u]*b[u,i]
                s=""
                s+= "a%s%s (%s)"%(j,u,a[j,u]) 
                s+=    "* b %s%s (%s)"%(u,i,b[u,i])
                s+=    "=(%s)"%(v_)
                if debug:
                    print s,
                v+=v_
            c[j,i]=v
            if debug:
                print "= %s "%v
        if debug: print
    if debug: print c
    return c
    
import random
from random import randint
def rndmat(shape=(0,0)):
    r,c=shape

    if not r : r=randint(1,3)
    if not c : c=randint(1,3)
    
    m=Mat(shape=(r,c))
    #print m,r,c
    
    for j in xrange(r) :
        for i in xrange(c):
            m[j,i]=float(randint(1,10))
    return m
array=Mat
if __name__=="__main__":
    
    import numpy as np 
    Mat.debug=False
    m=rndmat()
    print m,m.shape
    n=rndmat( ( m.shape[1],0) )

    print n,n.shape
    mn=dot(m,n)
    print "dot",mn,mn.shape
    
    a=np.array(m.lst).reshape(m.shape)
    b=np.array(n.lst).reshape(n.shape)
    print "a=",a
    print "b=",b



    ab=np.dot(a,b)
    print "dot(a,b)=",ab
    rows,cols=ab.shape
    #ab=list(ab.reshape(rows*cols))

    assert list(ab.reshape(rows*cols))==list(mn)
    

    if m.shape==n.shape :
        ab=a+b;rows,cols=ab.shape
        assert list((ab).reshape(rows*cols))==list(m+n)

        ab=a-b;rows,cols=ab.shape
        assert list((ab).reshape(rows*cols))==list(m-n)
        ab=a*b; rows,cols=ab.shape

        assert list((ab).reshape(rows*cols))==list(m*n)

        ab=a/b; rows,cols=ab.shape

        assert list((ab).reshape(rows*cols))==list(m*n)


    x=random.randint(0,100)
    c=a+x
    rows,cols=c.shape
    assert list(c.reshape(rows*cols))==list(m+x)

    c=a-x
    rows,cols=c.shape
    assert list(c.reshape(rows*cols))==list(m-x)

    c=a*x
    rows,cols=c.shape
    assert list(c.reshape(rows*cols))==list(m*x)

    c=a/x
    rows,cols=c.shape
    assert list(c.reshape(rows*cols))==list(m/x)

    aa=array(shape=(4,3))
    assert aa.shape==(4,3) and aa.rows==4 and aa.cols==3
    aa2=aa.reshape(3,4)
    assert aa2.shape==(3,4) and aa2.rows==3 and aa2.cols==4
    
    print "OK"


コメントを残す

メールアドレスが公開されることはありません。