沿特定轴相加/相乘张量和向量

时间:2018-12-12 17:28:12

标签: python numpy tensorflow

如何在tensorflow和/或numpy中有效地有效地实现以下功能?

add_along_axis(tensor=T, vector=v, axis=k)
     # T is a tensor of shape (N1,...,Nd) (unknown beforehand)
     # v is a vector with N components
     # k is an integer such that Nk=N
     S = T+v, summed along k
     return S

S是具有组件(N1,..,Nd)的{​​{1}}张量

请注意,S[i1,...,id]=T[i1,...,id] + v[ik]Nj的任意数目都可能恰好等于j≠k,因此标准广播不是一种选择。

示例:让NT = np.zeros( (3,3,3) )然后正确的输出应该是

v = [1,2,3]

在这里,可以通过分别编写f(T,v,1) = [[[1., 1., 1.], [[2., 2., 2.], [[3., 3., 3.], [1., 1., 1.], [2., 2., 2.], [3., 3., 3.], [1., 1., 1.]], [2., 2., 2.]], [3., 3., 3.]]] f(T,v,2) = [[[1., 1., 1.], [[1., 1., 1.], [[1., 1., 1.], [2., 2., 2.], [2., 2., 2.], [2., 2., 2.], [3., 3., 3.]], [3., 3., 3.]], [3., 3., 3.]]] f(T,v,3) = [[[1., 2., 3.], [[1., 2., 3.], [[1., 2., 3.], [1., 2., 3.], [1., 2., 3.], [1., 2., 3.], [1., 2., 3.]], [1., 2., 3.]], [1., 2., 3.]]] T+v[:,None,None]T+v[None,:,None]来实现目标行为。但是,在未预定义张量形状的情况下,我看不到这种方法如何工作。

2 个答案:

答案 0 :(得分:0)

通过执行以下列表理解,您可以为T的任意尺寸和任何轴import {Component, ElementRef, OnInit, ViewChild} from '@angular/core'; import {ActivatedRoute, Params} from "@angular/router"; import {FormControl, FormGroup, Validators} from "@angular/forms"; import {CategoriesService} from "../../shared/services/categories.service"; import {switchMap} from "rxjs/operators"; import {of} from "rxjs"; import {MaterialService} from "../../shared/classes/material.service"; import {Category} from "../../shared/interfaces"; @Component({ selector: 'app-categories-form', templateUrl: './categories-form.component.html', styleUrls: ['./categories-form.component.css'] }) export class CategoriesFormComponent implements OnInit { @ViewChild('input') inputRef: ElementRef form: FormGroup image: File imagePreview='' isNew= true category: Category constructor(private route: ActivatedRoute, private categoriesService: CategoriesService) { } ngOnInit() { this.form=new FormGroup({ name:new FormControl(null, Validators.required) }) this.form.disable() this.route.params .pipe( switchMap( (params: Params)=>{ if(params['id']){ this.isNew=false return this.categoriesService.getById(params['id']) } return of (null) } ) ) .subscribe( (category: Category)=>{ if(category){ this.category = category this.form.patchValue({ name: category.name }) this.imagePreview=category.imageSrc MaterialService.updateTextInputs() } this.form.enable() }, error=>MaterialService.toast(error.error.message) ) } triggerClick(){ this.inputRef.nativeElement.click() } onFileUpload(event: any){ const file = event.target.files[0] this.image = file const reader = new FileReader() reader.onload = () => { this.imagePreview = reader.result as string } reader.readAsDataURL(file) } onSubmit(){ let obs$ this.form.disable() if(this.isNew){ obs$ = this.categoriesService.create(this.form.value.name, this.image) } else { obs$ = this.categoriesService.update(this.category._id, this.form.value.name, this.image) } obs$.subscribe( category =>{ this.category = category MaterialService.toast('Изменения сохранены!') this.form.enable() }, error=>{ MaterialService.toast(error.error.message) this.form.enable() } ) } } 自动生成create(name: string, image?:File) : Observable<Category>{ const fd = new FormData() if(image){ fd.append('image', image.name) } fd.append('name', name) return this.http.post<Category>('/api/category', fd) } update(id: string, name: string, image?:File) : Observable<Category>{ const fd= new FormData() if(image){ fd.append('image', image.name) } fd.append('name', name) return this.http.patch<Category>(`/api/category/${id}`, fd) }

v[:,None,None]

k等效于def f(T,v,k): return T+v[[np.newaxis if i+1 != k else slice(None) for i in range(T.ndim) ]] ,而np.newaxis等效于None。结果与预期的一样:

slice(None)

答案 1 :(得分:0)

只需将T.ndim-k的单位长度尺寸附加到v和numpy的广播规则上即可:

def f(T, v, k):
    v = asarray(v)
    return T + v.reshape(v.shape + (1,)*(T.ndim-k))

请注意,您对k的定义比标准的numpy轴编号大一;您可以考虑将k减1并将其称为“轴”。